aws_smt_ir/term/sort_checking.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3use crate::{
4    fold::{Fold, Folder},
5    Constant, CoreOp, Ctx, ICoreOp, IOp, IQuantifier, ISort, IVar, Identifier, Index, Logic,
6    QualIdentifier, Quantifier, Term, Void, IUF, UF,
7};
8use std::fmt::{Debug, Display};
9
10/// Defines sort-checking behavior for a type.
11pub trait Sorted<L: Logic> {
12    /// Determines the sort of `self`. If `self` is a variable, then its sort is the variable's
13    /// sort; if `self` is a function application, then its sort is the function's return type.
14    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>>;
15}
16
17#[derive(Debug)]
18pub struct SortChecker<'a>(&'a mut Ctx);
19
20impl<L: Logic> Sorted<L> for Term<L> {
21    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
22        self.fold_with(&mut SortChecker(ctx))
23    }
24}
25
26impl<L: Logic<Var = QualIdentifier>> Sorted<L> for UF<Term<L>> {
27    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
28        ctx.return_sort(&self.func).cloned().ok_or_else(|| {
29            // TODO: this isn't really what it is, but it'll print out right
30            UnknownSort(Term::from(IVar::from(QualIdentifier::from(
31                self.func.clone(),
32            ))))
33        })
34    }
35}
36
37impl<L: Logic> Sorted<L> for IUF<L> {
38    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
39        self.as_ref().sort(ctx)
40    }
41}
42
43impl<L: Logic> Sorted<L> for Quantifier<Term<L>> {
44    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
45        match self {
46            Self::Forall(_, term) | Self::Exists(_, term) => term.sort(ctx),
47        }
48    }
49}
50
51impl<L: Logic<Var = QualIdentifier>> Sorted<L> for QualIdentifier {
52    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
53        match self {
54            Self::Simple { identifier } => {
55                match identifier.as_ref() {
56                    Identifier::Simple { symbol } => ctx
57                        .const_sort(symbol)
58                        .cloned()
59                        .ok_or_else(|| UnknownSort(Term::Variable(self.into()))),
60                    Identifier::Indexed { symbol, indices } => {
61                        // Check for (_ bvX n) literals
62                        match indices.as_slice() {
63                            [Index::Numeral(n)] if symbol.0.starts_with("bv") => {
64                                Ok(ISort::bitvec(n.clone()))
65                            }
66                            _ => Err(UnknownSort(Term::Variable(self.into()))),
67                        }
68                    }
69                }
70            }
71            Self::Sorted { sort, .. } => Ok(sort.clone()),
72        }
73    }
74}
75
76impl<L: Logic> Sorted<L> for CoreOp<Term<L>> {
77    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
78        if let Self::Ite(_, t, e) = self {
79            let sort = t.sort(ctx)?;
80            debug_assert_eq!(e.sort(ctx)?, sort, "ite branches must be of the same sort");
81            Ok(sort)
82        } else {
83            Ok(ISort::bool())
84        }
85    }
86}
87
88impl<L: Logic> Sorted<L> for IQuantifier<L> {
89    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
90        self.as_ref().sort(ctx)
91    }
92}
93
94impl<L: Logic> Sorted<L> for ICoreOp<L> {
95    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
96        self.as_ref().sort(ctx)
97    }
98}
99
100impl<L: Logic> Sorted<L> for IOp<L> {
101    fn sort(&self, ctx: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
102        self.as_ref().sort(ctx)
103    }
104}
105
106impl<L: Logic> Sorted<L> for Void {
107    fn sort(&self, _: &mut Ctx) -> Result<ISort, UnknownSort<Term<L>>> {
108        match *self {}
109    }
110}
111
112#[derive(Debug, thiserror::Error, PartialEq, Eq)]
113#[error("unknown sort for: {0}")]
114pub struct UnknownSort<T: Debug + Display>(pub T);
115
116impl<L: Logic> Folder<L> for SortChecker<'_> {
117    type Output = ISort;
118    type Error = UnknownSort<Term<L>>;
119
120    fn fold_const(&mut self, constant: crate::IConst) -> Result<Self::Output, Self::Error> {
121        match constant.as_ref() {
122            Constant::Numeral(_) => Ok(ISort::int()),
123            Constant::String(_) => Ok(ISort::string()),
124            Constant::Binary(bits) => Ok(ISort::bitvec(bits.len().into())),
125            Constant::Hexadecimal(digits) => Ok(ISort::bitvec((digits.len() * 4).into())),
126            Constant::Decimal(_) => Ok(ISort::real()),
127        }
128    }
129
130    fn fold_var(&mut self, var: crate::IVar<L::Var>) -> Result<Self::Output, Self::Error> {
131        var.sort(self.0)
132    }
133
134    fn fold_core_op(&mut self, op: ICoreOp<L>) -> Result<Self::Output, Self::Error> {
135        op.sort(self.0)
136    }
137
138    fn fold_theory_op(&mut self, op: crate::IOp<L>) -> Result<Self::Output, Self::Error> {
139        op.sort(self.0)
140    }
141
142    fn fold_uninterpreted_func(&mut self, uf: crate::IUF<L>) -> Result<Self::Output, Self::Error> {
143        uf.sort(self.0)
144    }
145
146    fn fold_let(&mut self, l: crate::ILet<L>) -> Result<Self::Output, Self::Error> {
147        l.term.sort(self.0)
148    }
149
150    fn fold_match(&mut self, m: crate::IMatch<L>) -> Result<Self::Output, Self::Error> {
151        Err(UnknownSort(m.into()))
152    }
153
154    fn fold_quantifier(&mut self, quantifier: IQuantifier<L>) -> Result<Self::Output, Self::Error> {
155        quantifier.sort(self.0)
156    }
157}