hugr_core/std_extensions/
logic.rs

1//! Basic logical operations.
2
3use std::sync::{Arc, LazyLock, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::extension::{ConstFold, ConstFoldResult};
8use crate::ops::constant::ValueName;
9use crate::ops::{OpName, Value};
10use crate::types::Signature;
11use crate::{
12    Extension, IncomingPort,
13    extension::{
14        ExtensionId, OpDef, SignatureFunc,
15        prelude::bool_t,
16        simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError, try_from_name},
17    },
18    ops,
19    types::type_param::TypeArg,
20    utils::sorted_consts,
21};
22/// Name of extension false value.
23pub const FALSE_NAME: ValueName = ValueName::new_inline("FALSE");
24/// Name of extension true value.
25pub const TRUE_NAME: ValueName = ValueName::new_inline("TRUE");
26
27impl ConstFold for LogicOp {
28    fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
29        match self {
30            Self::And => {
31                let inps = read_inputs(consts)?;
32                let res = inps.iter().all(|x| *x);
33                // We can only fold to true if we have a const for all our inputs.
34                (!res || inps.len() as u64 == 2)
35                    .then_some(vec![(0.into(), ops::Value::from_bool(res))])
36            }
37            Self::Or => {
38                let inps = read_inputs(consts)?;
39                let res = inps.iter().any(|x| *x);
40                // We can only fold to false if we have a const for all our inputs
41                (res || inps.len() as u64 == 2)
42                    .then_some(vec![(0.into(), ops::Value::from_bool(res))])
43            }
44            Self::Eq => {
45                let inps = read_inputs(consts)?;
46                let res = inps.iter().copied().reduce(|a, b| a == b)?;
47                // If we have only some inputs, we can still fold to false, but not to true
48                (!res || inps.len() as u64 == 2)
49                    .then_some(vec![(0.into(), ops::Value::from_bool(res))])
50            }
51            Self::Not => {
52                let inps = read_inputs(consts)?;
53                let res = inps.iter().all(|x| !*x);
54                (!res || inps.len() as u64 == 1)
55                    .then_some(vec![(0.into(), ops::Value::from_bool(res))])
56            }
57            Self::Xor => {
58                let inps = read_inputs(consts)?;
59                let res = inps.iter().fold(false, |acc, x| acc ^ *x);
60                (inps.len() as u64 == 2).then_some(vec![(0.into(), ops::Value::from_bool(res))])
61            }
62        }
63    }
64}
65
66/// Logic extension operation definitions.
67#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
68#[allow(missing_docs)]
69#[non_exhaustive]
70pub enum LogicOp {
71    And,
72    Or,
73    Eq,
74    Not,
75    Xor,
76}
77
78impl MakeOpDef for LogicOp {
79    fn opdef_id(&self) -> OpName {
80        <&'static str>::from(self).into()
81    }
82
83    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
84        match self {
85            LogicOp::And | LogicOp::Or | LogicOp::Eq | LogicOp::Xor => {
86                Signature::new(vec![bool_t(); 2], vec![bool_t()])
87            }
88            LogicOp::Not => Signature::new_endo(vec![bool_t()]),
89        }
90        .into()
91    }
92
93    fn extension_ref(&self) -> Weak<Extension> {
94        Arc::downgrade(&EXTENSION)
95    }
96
97    fn description(&self) -> String {
98        match self {
99            LogicOp::And => "logical 'and'",
100            LogicOp::Or => "logical 'or'",
101            LogicOp::Eq => "test if bools are equal",
102            LogicOp::Not => "logical 'not'",
103            LogicOp::Xor => "logical 'xor'",
104        }
105        .to_string()
106    }
107
108    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
109        try_from_name(op_def.name(), op_def.extension_id())
110    }
111
112    fn extension(&self) -> ExtensionId {
113        EXTENSION_ID.clone()
114    }
115
116    fn post_opdef(&self, def: &mut OpDef) {
117        def.set_constant_folder(*self);
118    }
119}
120
121/// The extension identifier.
122pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic");
123/// Extension version.
124pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
125
126/// Extension for basic logical operations.
127fn extension() -> Arc<Extension> {
128    Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
129        LogicOp::load_all_ops(extension, extension_ref).unwrap();
130    })
131}
132
133/// Reference to the logic Extension.
134pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(extension);
135
136impl MakeRegisteredOp for LogicOp {
137    fn extension_id(&self) -> ExtensionId {
138        EXTENSION_ID.clone()
139    }
140
141    fn extension_ref(&self) -> Weak<Extension> {
142        Arc::downgrade(&EXTENSION)
143    }
144}
145
146fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option<Vec<bool>> {
147    let true_val = ops::Value::true_val();
148    let false_val = ops::Value::false_val();
149    let inps: Option<Vec<bool>> = sorted_consts(consts)
150        .into_iter()
151        .map(|c| {
152            if c == &true_val {
153                Some(true)
154            } else if c == &false_val {
155                Some(false)
156            } else {
157                None
158            }
159        })
160        .collect();
161    let inps = inps?;
162    Some(inps)
163}
164
#[cfg(test)]
pub(crate) mod test {
    use std::sync::Arc;

    use super::{LogicOp, extension};
    use crate::{
        Extension,
        extension::simple_op::{MakeOpDef, MakeRegisteredOp},
        ops::Value,
    };

    use rstest::rstest;
    use strum::IntoEnumIterator;

    /// The built extension has the right name, holds all five ops, and every op definition
    /// maps back to the variant it was generated from.
    #[test]
    fn test_logic_extension() {
        let ext: Arc<Extension> = extension();
        assert_eq!(ext.name() as &str, "logic");
        assert_eq!(ext.operations().count(), 5);

        LogicOp::iter().for_each(|op| {
            let def = ext.get_op(op.into()).unwrap();
            assert_eq!(LogicOp::from_def(def).unwrap(), op);
        });
    }

    /// Every op survives a round trip through its extension-op representation.
    #[test]
    fn test_conversions() {
        for op in LogicOp::iter() {
            let as_ext_op = op.to_extension_op().unwrap();
            let round_tripped = LogicOp::from_op(&as_ext_op).unwrap();
            assert_eq!(round_tripped, op);
        }
    }

    /// Generate a logic extension "and" operation over [`crate::extension::prelude::bool_t()`]
    pub(crate) fn and_op() -> LogicOp {
        LogicOp::And
    }

    /// Generate a logic extension "or" operation over [`crate::extension::prelude::bool_t()`]
    pub(crate) fn or_op() -> LogicOp {
        LogicOp::Or
    }

    /// Folding with every input known yields the expected boolean.
    #[rstest]
    #[case(LogicOp::And, [true, true], true)]
    #[case(LogicOp::And, [true, false], false)]
    #[case(LogicOp::Or, [false, true], true)]
    #[case(LogicOp::Or, [false, false], false)]
    #[case(LogicOp::Eq, [true, false], false)]
    #[case(LogicOp::Eq, [false, false], true)]
    #[case(LogicOp::Not, [false], true)]
    #[case(LogicOp::Not, [true], false)]
    #[case(LogicOp::Xor, [true, false], true)]
    #[case(LogicOp::Xor, [true, true], false)]
    fn const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = bool>,
        #[case] out: bool,
    ) {
        use itertools::Itertools;

        use crate::extension::ConstFold;

        let consts = ins
            .into_iter()
            .enumerate()
            .map(|(port, b)| (port.into(), Value::from_bool(b)))
            .collect_vec();
        let arity = consts.len() as u64;
        let expected = Some(vec![(0.into(), Value::from_bool(out))]);
        assert_eq!(expected, op.fold(&[arity.into()], &consts));
    }

    /// Folding with only some inputs known folds exactly when those inputs force the result.
    #[rstest]
    #[case(LogicOp::And, [Some(true), None], None)]
    #[case(LogicOp::And, [Some(false), None], Some(false))]
    #[case(LogicOp::Or, [None, Some(false)], None)]
    #[case(LogicOp::Or, [None, Some(true)], Some(true))]
    #[case(LogicOp::Eq, [None, Some(true)], None)]
    #[case(LogicOp::Not, [None], None)]
    #[case(LogicOp::Xor, [None, Some(true)], None)]
    fn partial_const_fold(
        #[case] op: LogicOp,
        #[case] ins: impl IntoIterator<Item = Option<bool>>,
        #[case] mb_out: Option<bool>,
    ) {
        use itertools::Itertools;

        use crate::extension::ConstFold;

        let indexed = ins.into_iter().enumerate().collect_vec();
        let num_args = indexed.len() as u64;
        // Keep only the known inputs, remembering which port each one belongs to.
        let known = indexed
            .into_iter()
            .filter_map(|(port, mb_b)| Some((port, mb_b?)))
            .map(|(port, b)| (port.into(), Value::from_bool(b)))
            .collect_vec();
        let expected = mb_out.map(|out| vec![(0.into(), Value::from_bool(out))]);
        assert_eq!(expected, op.fold(&[num_args.into()], &known));
    }
}