hugr_core/std_extensions/
logic.rs

1//! Basic logical operations.
2
3use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::extension::{ConstFold, ConstFoldResult};
8use crate::ops::constant::ValueName;
9use crate::ops::Value;
10use crate::types::Signature;
11use crate::{
12    extension::{
13        prelude::bool_t,
14        simple_op::{try_from_name, MakeOpDef, MakeRegisteredOp, OpLoadError},
15        ExtensionId, OpDef, SignatureFunc,
16    },
17    ops,
18    types::type_param::TypeArg,
19    utils::sorted_consts,
20    Extension, IncomingPort,
21};
22use lazy_static::lazy_static;
/// Name under which the boolean `false` value is registered in the logic extension.
pub const FALSE_NAME: ValueName = ValueName::new_inline("FALSE");
/// Name under which the boolean `true` value is registered in the logic extension.
pub const TRUE_NAME: ValueName = ValueName::new_inline("TRUE");
27
28impl ConstFold for LogicOp {
29    fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
30        match self {
31            Self::And => {
32                let inps = read_inputs(consts)?;
33                let res = inps.iter().all(|x| *x);
34                // We can only fold to true if we have a const for all our inputs.
35                (!res || inps.len() as u64 == 2)
36                    .then_some(vec![(0.into(), ops::Value::from_bool(res))])
37            }
38            Self::Or => {
39                let inps = read_inputs(consts)?;
40                let res = inps.iter().any(|x| *x);
41                // We can only fold to false if we have a const for all our inputs
42                (res || inps.len() as u64 == 2)
43                    .then_some(vec![(0.into(), ops::Value::from_bool(res))])
44            }
45            Self::Eq => {
46                let inps = read_inputs(consts)?;
47                let res = inps.iter().copied().reduce(|a, b| a == b)?;
48                // If we have only some inputs, we can still fold to false, but not to true
49                (!res || inps.len() as u64 == 2)
50                    .then_some(vec![(0.into(), ops::Value::from_bool(res))])
51            }
52            Self::Not => {
53                let inps = read_inputs(consts)?;
54                let res = inps.iter().all(|x| !*x);
55                (!res || inps.len() as u64 == 1)
56                    .then_some(vec![(0.into(), ops::Value::from_bool(res))])
57            }
58            Self::Xor => {
59                let inps = read_inputs(consts)?;
60                let res = inps.iter().fold(false, |acc, x| acc ^ *x);
61                (inps.len() as u64 == 2).then_some(vec![(0.into(), ops::Value::from_bool(res))])
62            }
63        }
64    }
65}
66
/// Logic extension operation definitions.
///
/// Each variant names one operation of the logic extension; the variant name
/// doubles as the operation name (via the `strum` derives).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum LogicOp {
    /// Logical 'and' of two booleans.
    And,
    /// Logical 'or' of two booleans.
    Or,
    /// Tests whether two booleans are equal.
    Eq,
    /// Logical 'not' of a single boolean.
    Not,
    /// Logical 'xor' of two booleans.
    Xor,
}
78
79impl MakeOpDef for LogicOp {
80    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
81        match self {
82            LogicOp::And | LogicOp::Or | LogicOp::Eq | LogicOp::Xor => {
83                Signature::new(vec![bool_t(); 2], vec![bool_t()])
84            }
85            LogicOp::Not => Signature::new_endo(vec![bool_t()]),
86        }
87        .into()
88    }
89
90    fn extension_ref(&self) -> Weak<Extension> {
91        Arc::downgrade(&EXTENSION)
92    }
93
94    fn description(&self) -> String {
95        match self {
96            LogicOp::And => "logical 'and'",
97            LogicOp::Or => "logical 'or'",
98            LogicOp::Eq => "test if bools are equal",
99            LogicOp::Not => "logical 'not'",
100            LogicOp::Xor => "logical 'xor'",
101        }
102        .to_string()
103    }
104
105    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
106        try_from_name(op_def.name(), op_def.extension_id())
107    }
108
109    fn extension(&self) -> ExtensionId {
110        EXTENSION_ID.to_owned()
111    }
112
113    fn post_opdef(&self, def: &mut OpDef) {
114        def.set_constant_folder(*self);
115    }
116}
117
/// Identifier of the logic extension (`"logic"`).
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic");
/// Version of the logic extension.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
122
123/// Extension for basic logical operations.
124fn extension() -> Arc<Extension> {
125    Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
126        LogicOp::load_all_ops(extension, extension_ref).unwrap();
127
128        extension
129            .add_value(FALSE_NAME, ops::Value::false_val())
130            .unwrap();
131        extension
132            .add_value(TRUE_NAME, ops::Value::true_val())
133            .unwrap();
134    })
135}
136
lazy_static! {
    /// Shared reference to the logic [`Extension`], initialised through
    /// `lazy_static!` by calling `extension()`.
    pub static ref EXTENSION: Arc<Extension> = extension();
}
141
142impl MakeRegisteredOp for LogicOp {
143    fn extension_id(&self) -> ExtensionId {
144        EXTENSION_ID.to_owned()
145    }
146
147    fn extension_ref(&self) -> Weak<Extension> {
148        Arc::downgrade(&EXTENSION)
149    }
150}
151
152fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option<Vec<bool>> {
153    let true_val = ops::Value::true_val();
154    let false_val = ops::Value::false_val();
155    let inps: Option<Vec<bool>> = sorted_consts(consts)
156        .into_iter()
157        .map(|c| {
158            if c == &true_val {
159                Some(true)
160            } else if c == &false_val {
161                Some(false)
162            } else {
163                None
164            }
165        })
166        .collect();
167    let inps = inps?;
168    Some(inps)
169}
170
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Arc;

    use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME};
    use crate::{
        extension::{
            prelude::bool_t,
            simple_op::{MakeOpDef, MakeRegisteredOp},
        },
        ops::{NamedOp, Value},
        Extension,
    };

    use rstest::rstest;
    use strum::IntoEnumIterator;

    /// The extension is called "logic", holds five operations, and every
    /// operation can be recovered from its registered definition.
    #[test]
    fn test_logic_extension() {
        let logic_ext: Arc<Extension> = extension();
        assert_eq!(logic_ext.name() as &str, "logic");
        assert_eq!(logic_ext.operations().count(), 5);

        LogicOp::iter().for_each(|op| {
            let op_name = op.name();
            let def = logic_ext.get_op(&op_name).unwrap();
            let recovered = LogicOp::from_def(def).unwrap();
            assert_eq!(recovered, op);
        });
    }

    /// Every operation survives a round trip through an extension op.
    #[test]
    fn test_conversions() {
        LogicOp::iter().for_each(|op| {
            let ext_op = op.to_extension_op().unwrap();
            let recovered = LogicOp::from_op(&ext_op).unwrap();
            assert_eq!(recovered, op);
        });
    }

    /// Both named values of the extension have the boolean type.
    #[test]
    fn test_values() {
        let logic_ext: Arc<Extension> = extension();
        let falsy = logic_ext.get_value(&FALSE_NAME).unwrap();
        let truthy = logic_ext.get_value(&TRUE_NAME).unwrap();

        [falsy, truthy].into_iter().for_each(|named| {
            let ty = named.typed_value().get_type();
            assert_eq!(ty, bool_t());
        });
    }

    /// Generate a logic extension "and" operation over [`bool_t`].
    pub(crate) fn and_op() -> LogicOp {
        LogicOp::And
    }

    /// Generate a logic extension "or" operation over [`bool_t`].
    pub(crate) fn or_op() -> LogicOp {
        LogicOp::Or
    }

    /// Folding with every input known yields the expected boolean.
    #[rstest]
    #[case(LogicOp::And, [true, true], true)]
    #[case(LogicOp::And, [true, false], false)]
    #[case(LogicOp::Or, [false, true], true)]
    #[case(LogicOp::Or, [false, false], false)]
    #[case(LogicOp::Eq, [true, false], false)]
    #[case(LogicOp::Eq, [false, false], true)]
    #[case(LogicOp::Not, [false], true)]
    #[case(LogicOp::Not, [true], false)]
    #[case(LogicOp::Xor, [true, false], true)]
    #[case(LogicOp::Xor, [true, true], false)]
    fn const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = bool>,
        #[case] out: bool,
    ) {
        use crate::extension::ConstFold;

        // One constant per input port, in port order.
        let in_vals = ins
            .into_iter()
            .enumerate()
            .map(|(port, bit)| (port.into(), Value::from_bool(bit)))
            .collect::<Vec<_>>();
        let num_args = in_vals.len() as u64;

        let expected = Some(vec![(0.into(), Value::from_bool(out))]);
        assert_eq!(expected, op.fold(&[num_args.into()], &in_vals));
    }

    /// Folding with some inputs unknown (`None`) only succeeds when the known
    /// inputs already determine the result.
    #[rstest]
    #[case(LogicOp::And, [Some(true), None], None)]
    #[case(LogicOp::And, [Some(false), None], Some(false))]
    #[case(LogicOp::Or, [None, Some(false)], None)]
    #[case(LogicOp::Or, [None, Some(true)], Some(true))]
    #[case(LogicOp::Eq, [None, Some(true)], None)]
    #[case(LogicOp::Not, [None], None)]
    #[case(LogicOp::Xor, [None, Some(true)], None)]
    fn partial_const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = Option<bool>>,
        #[case] mb_out: Option<bool>,
    ) {
        use crate::extension::ConstFold;

        let slots: Vec<(usize, Option<bool>)> = ins.into_iter().enumerate().collect();
        let num_args = slots.len() as u64;

        // Keep only the ports whose value is known, preserving their index.
        let known = slots
            .into_iter()
            .filter_map(|(port, maybe_bit)| {
                let bit = maybe_bit?;
                Some((port.into(), Value::from_bool(bit)))
            })
            .collect::<Vec<_>>();

        let expected = mb_out.map(|out| vec![(0.into(), Value::from_bool(out))]);
        assert_eq!(expected, op.fold(&[num_args.into()], &known));
    }
}