machine_check_common/iir/expr/
call.rs

1use std::fmt::Debug;
2
3use mck::{
4    abstr::{AbstractValue, BitvectorDomain},
5    concr::RConcreteBitvector,
6    forward::ReadWrite,
7    misc::{Join, RBound},
8    refin::{RefinementDomain, RefinementValue},
9    ThreeValued,
10};
11use serde::{Deserialize, Serialize};
12
13use crate::iir::context::IFnContext;
14use crate::iir::description::IFnId;
15use crate::iir::expr::op::IMckExt;
16use crate::iir::{
17    expr::op::{IMckBinary, IMckUnary},
18    variable::IVarId,
19};
20use crate::iir::{join_limited, IAbstr, IRefin};
21use crate::ir_common::IrTypeArray;
22
23#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
24pub enum IMckNew {
25    Bitvector(RConcreteBitvector),
26    BitvectorArray(IrTypeArray, IVarId),
27}
28
29impl IMckNew {
30    fn forward_interpret(&self, abstr: &IAbstr) -> AbstractValue {
31        match self {
32            IMckNew::Bitvector(bitvector) => {
33                AbstractValue::Bitvector(mck::abstr::RBitvector::single_value(*bitvector))
34            }
35            IMckNew::BitvectorArray(ty, element) => {
36                let element = *abstr.value(*element).expect_bitvector();
37                mck::abstr::AbstractValue::Array(mck::abstr::RArray::new_filled(
38                    RBound::new(ty.index_width),
39                    element,
40                ))
41            }
42        }
43    }
44
45    fn backward_interpret(&self, abstr: &IAbstr, refin: &mut IRefin, later: RefinementValue) {
46        match self {
47            IMckNew::Bitvector(_) => {
48                // nothing to propagate to
49            }
50            IMckNew::BitvectorArray(_ty, var_id) => {
51                // overlay the markings and propagate back
52                let later = later.expect_array();
53                let earlier = later.earlier_element();
54
55                join_limited(abstr, refin, *var_id, RefinementValue::Bitvector(earlier))
56            }
57        }
58    }
59}
60
61impl Debug for IMckNew {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        match self {
64            Self::Bitvector(bitvector) => {
65                write!(f, "Bitvector::new({})", bitvector)
66            }
67            IMckNew::BitvectorArray(ty, element) => {
68                write!(
69                    f,
70                    "Bitvector::<{},{}>::new({:?})",
71                    ty.index_width, ty.element_width, element
72                )
73            }
74        }
75    }
76}
77
78#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
79pub struct IArrayRead {
80    pub base: IVarId,
81    pub index: IVarId,
82}
83
84#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
85pub struct IArrayWrite {
86    pub base: IVarId,
87    pub index: IVarId,
88    pub element: IVarId,
89}
90
91#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
92pub enum IExprCall {
93    Call(ICall),
94    MckUnary(IMckUnary),
95    MckBinary(IMckBinary),
96    MckExt(IMckExt),
97    MckNew(IMckNew),
98    BooleanNew(bool),
99    StdClone(IVarId),
100    ArrayRead(IArrayRead),
101    ArrayWrite(IArrayWrite),
102    Phi(IPhi),
103}
104
105#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
106pub struct IPhi {
107    pub condition: IVarId,
108    pub then_var_id: IVarId,
109    pub else_var_id: IVarId,
110}
111
112impl IExprCall {
113    pub fn forward_interpret(&self, context: &IFnContext, abstr: &IAbstr) -> Option<AbstractValue> {
114        Some(match self {
115            IExprCall::Call(call) => call.forward_interpret(context, abstr),
116            IExprCall::MckUnary(unary) => unary.forward_interpret(abstr),
117            IExprCall::MckBinary(binary) => binary.forward_interpret(abstr),
118            IExprCall::MckExt(ext) => ext.forward_interpret(abstr),
119            IExprCall::MckNew(mck_new) => mck_new.forward_interpret(abstr),
120            IExprCall::BooleanNew(value) => AbstractValue::Boolean(
121                mck::abstr::Boolean::from_three_valued(ThreeValued::from_bool(*value)),
122            ),
123            IExprCall::ArrayRead(array_read) => {
124                let array = abstr.value(array_read.base).expect_array();
125                let index = abstr.value(array_read.index).expect_bitvector();
126
127                AbstractValue::Bitvector(array.read(*index))
128            }
129            IExprCall::ArrayWrite(array_write) => {
130                let array = abstr.value(array_write.base).expect_array();
131                let index = abstr.value(array_write.index).expect_bitvector();
132                let element = abstr.value(array_write.element).expect_bitvector();
133
134                AbstractValue::Array(array.write(*index, *element))
135            }
136            IExprCall::Phi(phi) => {
137                // join the left and right variable value
138                // at least one must be present, but not necessarily both
139                let left = abstr.value_opt(phi.then_var_id);
140                let right = abstr.value_opt(phi.else_var_id);
141
142                match (left, right) {
143                    (Some(left), Some(right)) => left.clone().join(right),
144                    (Some(left), None) => left.clone(),
145                    (None, Some(right)) => right.clone(),
146                    (None, None) => panic!("At least one phi variable should be present"),
147                }
148            }
149            IExprCall::StdClone(var_id) => {
150                // clone
151                abstr.value(*var_id).clone()
152            }
153        })
154    }
155    pub fn backward_interpret(
156        &self,
157        context: &IFnContext,
158        abstr: &IAbstr,
159        refin: &mut IRefin,
160        refin_later: RefinementValue,
161    ) {
162        match self {
163            IExprCall::Call(call) => call.backward_interpret(context, abstr, refin, refin_later),
164            IExprCall::MckUnary(unary) => unary.backward_interpret(abstr, refin, refin_later),
165            IExprCall::MckBinary(binary) => binary.backward_interpret(abstr, refin, refin_later),
166            IExprCall::MckExt(ext) => ext.backward_interpret(abstr, refin, refin_later),
167            IExprCall::MckNew(new) => new.backward_interpret(abstr, refin, refin_later),
168            IExprCall::BooleanNew(_) => {
169                // there is no variable to propagate to, do nothing
170            }
171            IExprCall::StdClone(var_id) => {
172                // limit and join previous
173
174                join_limited(abstr, refin, *var_id, refin_later);
175            }
176            IExprCall::Phi(phi) => {
177                // propagate into both
178                // the abstract value might not be present, limit manually, skipping when not present
179                if let Some(abstr_a) = abstr.value_opt(phi.then_var_id) {
180                    let refin_a = refin_later.clone().limit(abstr_a);
181                    refin.join_value(phi.then_var_id, refin_a);
182                }
183                if let Some(abstr_b) = abstr.value_opt(phi.else_var_id) {
184                    let refin_b = refin_later.clone().limit(abstr_b);
185                    refin.join_value(phi.else_var_id, refin_b);
186                }
187
188                // convert to condition and propagate
189                let condition_value = refin_later.to_condition();
190
191                join_limited(
192                    abstr,
193                    refin,
194                    phi.condition,
195                    RefinementValue::Boolean(condition_value),
196                )
197            }
198            IExprCall::ArrayRead(array_read) => {
199                let refin_element = refin_later.expect_bitvector();
200                let abstr_earlier = abstr.value(array_read.base).expect_array();
201                let abstr_index = abstr.value(array_read.index).expect_bitvector();
202                let (refin_earlier, refin_index) =
203                    mck::backward::ReadWrite::read((abstr_earlier, *abstr_index), *refin_element);
204
205                // we already have the abstract values, limit them here
206                let refin_earlier = refin_earlier.limit(abstr_earlier);
207                let refin_index = refin_index.limit(abstr_index);
208
209                refin.join_value(array_read.base, RefinementValue::Array(refin_earlier));
210                refin.join_value(array_read.index, RefinementValue::Bitvector(refin_index));
211            }
212            IExprCall::ArrayWrite(array_write) => {
213                let RefinementValue::Array(refin_later) = refin_later else {
214                    panic!("Array write later should be an array");
215                };
216
217                let abstr_earlier = abstr.value(array_write.base).expect_array();
218                let abstr_index = abstr.value(array_write.index).expect_bitvector();
219                let abstr_element = abstr.value(array_write.element).expect_bitvector();
220
221                let (refin_earlier, refin_index, refin_element) = mck::backward::ReadWrite::write(
222                    (abstr_earlier, *abstr_index, *abstr_element),
223                    refin_later.clone(),
224                );
225
226                // we already have the abstract values, limit them here
227                let refin_earlier = refin_earlier.limit(abstr_earlier);
228                let refin_index = refin_index.limit(abstr_index);
229                let refin_element = refin_element.limit(abstr_element);
230
231                refin.join_value(array_write.base, RefinementValue::Array(refin_earlier));
232                refin.join_value(array_write.index, RefinementValue::Bitvector(refin_index));
233                refin.join_value(
234                    array_write.element,
235                    RefinementValue::Bitvector(refin_element),
236                );
237            }
238        }
239    }
240}
241
242impl ICall {
243    pub fn forward_interpret(&self, context: &IFnContext, abstr: &IAbstr) -> AbstractValue {
244        let func = context.context.fn_with_id(self.func);
245        let mut input_values = Vec::new();
246
247        for var_id in self.args.iter().cloned() {
248            let input_value = abstr.value(var_id).clone();
249            input_values.push(input_value);
250        }
251
252        let (normal, panic) = func.call(context.context, input_values);
253        AbstractValue::Struct(vec![normal, panic])
254    }
255
256    pub fn backward_interpret(
257        &self,
258        context: &IFnContext,
259        abstr: &IAbstr,
260        refin: &mut IRefin,
261        refin_later: RefinementValue,
262    ) {
263        let func = context.context.fn_with_id(self.func);
264
265        let refin_later = refin_later.expect_struct();
266
267        let later_normal = refin_later[0].clone();
268        let later_panic = refin_later[1].clone();
269
270        let mut input_values = Vec::new();
271
272        for var_id in self.args.iter().cloned() {
273            let input_value = abstr.value(var_id).clone();
274            input_values.push(input_value);
275        }
276
277        let func_abstr = func.forward_interpret(context.context, input_values);
278
279        let func_refin =
280            func.backward_interpret(context.context, &func_abstr, later_normal, later_panic);
281
282        let refin_inputs = func.backward_earlier(&func_abstr, &func_refin);
283
284        for (refin_var_id, refin_input) in self.args.iter().zip(refin_inputs) {
285            join_limited(abstr, refin, *refin_var_id, refin_input)
286        }
287    }
288}
289
290impl Debug for IExprCall {
291    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292        match self {
293            IExprCall::Call(call) => call.fmt(f),
294            IExprCall::MckUnary(unary) => unary.fmt(f),
295            IExprCall::MckBinary(binary) => binary.fmt(f),
296            IExprCall::MckExt(ext) => ext.fmt(f),
297            IExprCall::MckNew(mck_new) => mck_new.fmt(f),
298            IExprCall::StdClone(var_id) => {
299                write!(f, "StdClone({:?})", var_id)
300            }
301            IExprCall::ArrayRead(array_read) => {
302                write!(f, "{:?}[{:?}]", array_read.base, array_read.index)
303            }
304            IExprCall::ArrayWrite(array_write) => {
305                write!(
306                    f,
307                    "({:?}[{:?}] <-- {:?})",
308                    array_write.base, array_write.index, array_write.element
309                )
310            }
311            IExprCall::BooleanNew(value) => write!(f, "Boolean({:?})", value),
312            IExprCall::Phi(phi) => {
313                write!(
314                    f,
315                    "{:?} ? {:?} : {:?}",
316                    phi.condition, phi.then_var_id, phi.else_var_id
317                )
318            }
319        }
320    }
321}
322
323#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
324pub struct ICall {
325    pub func: IFnId,
326    pub args: Vec<IVarId>,
327}
328
329impl Debug for ICall {
330    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        write!(f, "{:?}(", self.func)?;
332        for arg in &self.args {
333            write!(f, "{:?}, ", arg)?;
334        }
335        write!(f, ")")
336    }
337}