hugr_core/std_extensions/
logic.rs1use std::sync::{Arc, LazyLock, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::extension::{ConstFold, ConstFoldResult};
8use crate::ops::constant::ValueName;
9use crate::ops::{OpName, Value};
10use crate::types::Signature;
11use crate::{
12 Extension, IncomingPort,
13 extension::{
14 ExtensionId, OpDef, SignatureFunc,
15 prelude::bool_t,
16 simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError, try_from_name},
17 },
18 ops,
19 types::type_param::TypeArg,
20 utils::sorted_consts,
21};
/// The [`ValueName`] `"FALSE"`, presumably naming the boolean false value.
pub const FALSE_NAME: ValueName = ValueName::new_inline("FALSE");
/// The [`ValueName`] `"TRUE"`, presumably naming the boolean true value.
pub const TRUE_NAME: ValueName = ValueName::new_inline("TRUE");
26
27impl ConstFold for LogicOp {
28 fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
29 match self {
30 Self::And => {
31 let inps = read_inputs(consts)?;
32 let res = inps.iter().all(|x| *x);
33 (!res || inps.len() as u64 == 2)
35 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
36 }
37 Self::Or => {
38 let inps = read_inputs(consts)?;
39 let res = inps.iter().any(|x| *x);
40 (res || inps.len() as u64 == 2)
42 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
43 }
44 Self::Eq => {
45 let inps = read_inputs(consts)?;
46 let res = inps.iter().copied().reduce(|a, b| a == b)?;
47 (!res || inps.len() as u64 == 2)
49 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
50 }
51 Self::Not => {
52 let inps = read_inputs(consts)?;
53 let res = inps.iter().all(|x| !*x);
54 (!res || inps.len() as u64 == 1)
55 .then_some(vec![(0.into(), ops::Value::from_bool(res))])
56 }
57 Self::Xor => {
58 let inps = read_inputs(consts)?;
59 let res = inps.iter().fold(false, |acc, x| acc ^ *x);
60 (inps.len() as u64 == 2).then_some(vec![(0.into(), ops::Value::from_bool(res))])
61 }
62 }
63 }
64}
65
/// Boolean operations defined by the logic extension.
///
/// `Not` takes a single bool; every other op takes two bools and returns one
/// (see `init_signature` in the `MakeOpDef` impl below).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum LogicOp {
    /// Logical 'and' of two bools.
    And,
    /// Logical 'or' of two bools.
    Or,
    /// Whether two bools are equal.
    Eq,
    /// Logical negation of one bool.
    Not,
    /// Logical exclusive-or of two bools.
    Xor,
}
77
78impl MakeOpDef for LogicOp {
79 fn opdef_id(&self) -> OpName {
80 <&'static str>::from(self).into()
81 }
82
83 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
84 match self {
85 LogicOp::And | LogicOp::Or | LogicOp::Eq | LogicOp::Xor => {
86 Signature::new(vec![bool_t(); 2], vec![bool_t()])
87 }
88 LogicOp::Not => Signature::new_endo(vec![bool_t()]),
89 }
90 .into()
91 }
92
93 fn extension_ref(&self) -> Weak<Extension> {
94 Arc::downgrade(&EXTENSION)
95 }
96
97 fn description(&self) -> String {
98 match self {
99 LogicOp::And => "logical 'and'",
100 LogicOp::Or => "logical 'or'",
101 LogicOp::Eq => "test if bools are equal",
102 LogicOp::Not => "logical 'not'",
103 LogicOp::Xor => "logical 'xor'",
104 }
105 .to_string()
106 }
107
108 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
109 try_from_name(op_def.name(), op_def.extension_id())
110 }
111
112 fn extension(&self) -> ExtensionId {
113 EXTENSION_ID.clone()
114 }
115
116 fn post_opdef(&self, def: &mut OpDef) {
117 def.set_constant_folder(*self);
118 }
119}
120
/// Identifier of the logic extension, `"logic"`.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic");
/// Version of the logic extension.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
125
126fn extension() -> Arc<Extension> {
128 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
129 LogicOp::load_all_ops(extension, extension_ref).unwrap();
130 })
131}
132
/// Shared instance of the logic extension, built by `extension()` on first access.
pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(extension);
135
136impl MakeRegisteredOp for LogicOp {
137 fn extension_id(&self) -> ExtensionId {
138 EXTENSION_ID.clone()
139 }
140
141 fn extension_ref(&self) -> Weak<Extension> {
142 Arc::downgrade(&EXTENSION)
143 }
144}
145
146fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option<Vec<bool>> {
147 let true_val = ops::Value::true_val();
148 let false_val = ops::Value::false_val();
149 let inps: Option<Vec<bool>> = sorted_consts(consts)
150 .into_iter()
151 .map(|c| {
152 if c == &true_val {
153 Some(true)
154 } else if c == &false_val {
155 Some(false)
156 } else {
157 None
158 }
159 })
160 .collect();
161 let inps = inps?;
162 Some(inps)
163}
164
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Arc;

    use super::{LogicOp, extension};
    use crate::{
        Extension,
        extension::simple_op::{MakeOpDef, MakeRegisteredOp},
        ops::Value,
    };

    use rstest::rstest;
    use strum::IntoEnumIterator;

    /// The extension is named "logic", declares five ops, and each op
    /// round-trips through its definition.
    #[test]
    fn test_logic_extension() {
        let ext: Arc<Extension> = extension();
        assert_eq!(ext.name() as &str, "logic");
        assert_eq!(ext.operations().count(), 5);

        LogicOp::iter().for_each(|op| {
            let def = ext.get_op(op.into()).unwrap();
            assert_eq!(LogicOp::from_def(def).unwrap(), op);
        });
    }

    /// Every op survives conversion to an extension op and back.
    #[test]
    fn test_conversions() {
        for op in LogicOp::iter() {
            let ext_op = op.to_extension_op().unwrap();
            let round_tripped = LogicOp::from_op(&ext_op).unwrap();
            assert_eq!(round_tripped, op);
        }
    }

    /// `LogicOp::And`, exposed for other test modules in the crate.
    pub(crate) fn and_op() -> LogicOp {
        LogicOp::And
    }

    /// `LogicOp::Or`, exposed for other test modules in the crate.
    pub(crate) fn or_op() -> LogicOp {
        LogicOp::Or
    }

    /// Folding with every input known gives the expected single output.
    #[rstest]
    #[case(LogicOp::And, [true, true], true)]
    #[case(LogicOp::And, [true, false], false)]
    #[case(LogicOp::Or, [false, true], true)]
    #[case(LogicOp::Or, [false, false], false)]
    #[case(LogicOp::Eq, [true, false], false)]
    #[case(LogicOp::Eq, [false, false], true)]
    #[case(LogicOp::Not, [false], true)]
    #[case(LogicOp::Not, [true], false)]
    #[case(LogicOp::Xor, [true, false], true)]
    #[case(LogicOp::Xor, [true, true], false)]
    fn const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = bool>,
        #[case] out: bool,
    ) {
        use itertools::Itertools;

        use crate::extension::ConstFold;

        let inputs = ins
            .into_iter()
            .enumerate()
            .map(|(port, b)| (port.into(), Value::from_bool(b)))
            .collect_vec();
        let n_inputs = inputs.len() as u64;
        assert_eq!(
            Some(vec![(0.into(), Value::from_bool(out))]),
            op.fold(&[n_inputs.into()], &inputs)
        );
    }

    /// Folding with some inputs unknown (`None`) folds only when the known
    /// inputs decide the result.
    #[rstest]
    #[case(LogicOp::And, [Some(true), None], None)]
    #[case(LogicOp::And, [Some(false), None], Some(false))]
    #[case(LogicOp::Or, [None, Some(false)], None)]
    #[case(LogicOp::Or, [None, Some(true)], Some(true))]
    #[case(LogicOp::Eq, [None, Some(true)], None)]
    #[case(LogicOp::Not, [None], None)]
    #[case(LogicOp::Xor, [None, Some(true)], None)]
    fn partial_const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = Option<bool>>,
        #[case] mb_out: Option<bool>,
    ) {
        use itertools::Itertools;

        use crate::extension::ConstFold;

        let slots = ins.into_iter().collect_vec();
        let num_args = slots.len() as u64;
        // Keep only the known inputs, each still tagged with its original port.
        let known_inputs = slots
            .into_iter()
            .enumerate()
            .filter_map(|(port, mb_b)| Some((port.into(), Value::from_bool(mb_b?))))
            .collect_vec();
        assert_eq!(
            mb_out.map(|out| vec![(0.into(), Value::from_bool(out))]),
            op.fold(&[num_args.into()], &known_inputs)
        );
    }
}