open_vaf/analysis/
constant_folding.rs

1//  * ******************************************************************************************
2//  * Copyright (c) 2019 Pascal Kuthe. This file is part of the OpenVAF project.
3//  * It is subject to the license terms in the LICENSE file found in the top-level directory
4//  *  of this distribution and at  https://gitlab.com/DSPOM/OpenVAF/blob/master/LICENSE.
5//  *  No part of OpenVAF, including this file, may be copied, modified, propagated, or
6//  *  distributed except according to the terms contained in the LICENSE file.
7//  * *******************************************************************************************
8
9#![allow(clippy::float_cmp)]
10#![allow(clippy::similar_names)]
11
12
13
14use crate::analysis::data_flow::reaching_variables::UseDefGraph;
15use crate::ast::UnaryOperator;
16
17use crate::cfg::{BasicBlock, Terminator};
18use crate::data_structures::BitSet;
19use crate::ir::cfg::{BasicBlockId, ControlFlowGraph};
20use crate::ir::mir::Mir;
21use crate::ir::{
22    BuiltInFunctionCall1p, BuiltInFunctionCall2p, IntegerExpressionId, ParameterId,
23    RealExpressionId, StatementId, StringExpressionId, VariableId,
24};
25use crate::literals::StringLiteral;
26use crate::mir::{
27    ComparisonOperator, ExpressionId, IntegerBinaryOperator, IntegerExpression, RealBinaryOperator,
28    RealExpression, Statement, StringExpression,
29};
30use bitflags::_core::option::Option::Some;
31use index_vec::IndexVec;
32use log::*;
33use rustc_hash::FxHashMap;
34
35#[derive(Clone, Debug, Default)]
36pub struct ConstantFoldState {
37    pub real_definitions: FxHashMap<StatementId, f64>,
38    pub integer_definitions: FxHashMap<StatementId, i64>,
39    pub string_definitions: FxHashMap<StatementId, StringLiteral>,
40    pub real_parameters: FxHashMap<ParameterId, f64>,
41    pub int_parameters: FxHashMap<ParameterId, i64>,
42    pub string_parameters: FxHashMap<ParameterId, StringLiteral>,
43}
44
45impl ConstantResolver for () {
46    #[inline(always)]
47    fn get_real_variable_value(&mut self, _var: VariableId) -> Option<f64> {
48        None
49    }
50
51    #[inline(always)]
52    fn get_int_variable_value(&mut self, _var: VariableId) -> Option<i64> {
53        None
54    }
55
56    #[inline(always)]
57    fn get_str_variable_value(&mut self, _var: VariableId) -> Option<StringLiteral> {
58        None
59    }
60
61    #[inline(always)]
62    fn get_real_parameter_value(&mut self, _param: ParameterId) -> Option<f64> {
63        None
64    }
65
66    #[inline(always)]
67    fn get_int_parameter_value(&mut self, _param: ParameterId) -> Option<i64> {
68        None
69    }
70
71    #[inline(always)]
72    fn get_str_parameter_value(&mut self, _param: ParameterId) -> Option<StringLiteral> {
73        None
74    }
75}
76
77pub struct ConstantPropagator<'lt> {
78    known_values: &'lt ConstantFoldState,
79    dependencys_before: &'lt BitSet<StatementId>,
80    dependencys_after: &'lt mut BitSet<StatementId>,
81    variables_assignments: &'lt IndexVec<VariableId, BitSet<StatementId>>,
82}
83
84impl<'lt> ConstantResolver for ConstantPropagator<'lt> {
85    #[inline]
86    fn get_real_variable_value(&mut self, var: VariableId) -> Option<f64> {
87        let mut definitions = self
88            .dependencys_before
89            .intersection(&self.variables_assignments[var])
90            .map(|id| self.known_values.real_definitions.get(&id));
91
92        // TODO constant fold default values
93        let value = *definitions.next().flatten()?;
94        if definitions.any(|x| Some(&value) != x) {
95            return None;
96        }
97        self.dependencys_after
98            .difference_with(&self.variables_assignments[var]);
99
100        Some(value)
101    }
102
103    #[inline]
104    fn get_int_variable_value(&mut self, var: VariableId) -> Option<i64> {
105        let mut definitions = self
106            .dependencys_before
107            .intersection(&self.variables_assignments[var])
108            .map(|id| self.known_values.integer_definitions.get(&id));
109
110        // TODO constant fold default values
111        let value = *definitions.next().flatten()?;
112        if definitions.any(|x| Some(&value) != x) {
113            return None;
114        }
115        self.dependencys_after
116            .difference_with(&self.variables_assignments[var]);
117
118        Some(value)
119    }
120
121    #[inline]
122    fn get_str_variable_value(&mut self, var: VariableId) -> Option<StringLiteral> {
123        let mut definitions = self
124            .dependencys_before
125            .intersection(&self.variables_assignments[var])
126            .map(|id| self.known_values.string_definitions.get(&id));
127
128        let value = *definitions.next().flatten()?;
129        if definitions.any(|x| Some(&value) != x) {
130            return None;
131        }
132        self.dependencys_after
133            .difference_with(&self.variables_assignments[var]);
134        Some(value)
135    }
136
137    #[inline]
138    fn get_real_parameter_value(&mut self, param: ParameterId) -> Option<f64> {
139        self.known_values.real_parameters.get(&param).copied()
140    }
141
142    #[inline]
143    fn get_int_parameter_value(&mut self, param: ParameterId) -> Option<i64> {
144        self.known_values.int_parameters.get(&param).copied()
145    }
146
147    #[inline]
148    fn get_str_parameter_value(&mut self, param: ParameterId) -> Option<StringLiteral> {
149        self.known_values.string_parameters.get(&param).copied()
150    }
151}
152
153pub trait ConstantResolver {
154    fn get_real_variable_value(&mut self, var: VariableId) -> Option<f64>;
155    fn get_int_variable_value(&mut self, var: VariableId) -> Option<i64>;
156    fn get_str_variable_value(&mut self, var: VariableId) -> Option<StringLiteral>;
157    fn get_real_parameter_value(&mut self, param: ParameterId) -> Option<f64>;
158    fn get_int_parameter_value(&mut self, param: ParameterId) -> Option<i64>;
159    fn get_str_parameter_value(&mut self, param: ParameterId) -> Option<StringLiteral>;
160}
161
162pub fn real_constant_fold(
163    fold: &mut impl ConstantFolder,
164    resolver: &mut impl ConstantResolver,
165    expr: RealExpressionId,
166) -> Option<f64> {
167    let res = match fold.mir()[expr].contents {
168        RealExpression::Literal(val) => return Some(val),
169        RealExpression::VariableReference(var) => resolver.get_real_variable_value(var)?,
170        RealExpression::ParameterReference(param) => resolver.get_real_parameter_value(param)?,
171
172        RealExpression::BinaryOperator(lhs_id, op, rhs_id) => {
173            let lhs = fold.real_constant_fold(resolver, lhs_id);
174            let rhs = fold.real_constant_fold(resolver, rhs_id);
175
176            /*
177                This fold also does some algebraic simplification here for expressions where only one operator can be folded.
178                Algebraic simiplifications are technically not IEEE-754 compliant because you have to special case NAN (NAN*0.0 = NAN)
179                In practice if you have a term z=x*y where x is known at compile time to be == 0.0 then you probably do intend z to be always 0.0
180                We do this because this non standard behavior is required to correctly calculate some derivatives
181                and it also dramatically speeds up compile times for code with lots of derivatives.
182                TODO properly document this
183                Note that simplifying 0.0 + x or 1.0*x to x is perfectly fine
184                These simplifications are however also not very interesting but are done because its faster to do here than at runtime / let llvm do it.
185                All simplifications that are not standard compliant will be marked with a comment
186            */
187
188            match op.contents {
189                RealBinaryOperator::Sum => match (lhs, rhs) {
190                    (Some(lhs), Some(rhs)) => lhs + rhs,
191                    (Some(lhs), None) if lhs == 0.0 => {
192                        fold.resolve_to_real_subexpressions(expr, rhs_id);
193                        return None;
194                    }
195                    (None, Some(rhs)) if rhs == 0.0 => {
196                        fold.resolve_to_real_subexpressions(expr, lhs_id);
197                        return None;
198                    }
199                    (_, _) => return None,
200                },
201                RealBinaryOperator::Subtract => match (lhs, rhs) {
202                    (Some(lhs), Some(rhs)) => lhs - rhs,
203                    (None, Some(rhs)) if rhs == 0.0 => {
204                        fold.resolve_to_real_subexpressions(expr, lhs_id);
205                        return None;
206                    }
207                    (_, _) => return None,
208                },
209                RealBinaryOperator::Multiply => match (lhs, rhs) {
210                    //not IEEE-754 compliant
211                    (Some(x), _) | (_, Some(x)) if x == 0.0 => 0.0,
212
213                    (Some(lhs), Some(rhs)) => lhs * rhs,
214
215                    (Some(lhs), None) if lhs == 1.0 => {
216                        fold.resolve_to_real_subexpressions(expr, rhs_id);
217                        return None;
218                    }
219
220                    (None, Some(rhs)) if rhs == 1.0 => {
221                        fold.resolve_to_real_subexpressions(expr, rhs_id);
222                        return None;
223                    }
224
225                    (_, _) => return None,
226                },
227                RealBinaryOperator::Divide => match (lhs, rhs) {
228                    //not IEEE-754 compliant
229                    (Some(lhs), _) if lhs == 0.0 => 0.0,
230                    (Some(lhs), Some(rhs)) => lhs / rhs,
231
232                    (None, Some(rhs)) if rhs == 1.0 => {
233                        fold.resolve_to_real_subexpressions(expr, lhs_id);
234                        return None;
235                    }
236                    (_, _) => return None,
237                },
238                RealBinaryOperator::Exponent => match (lhs, rhs) {
239                    //not IEEE-754 compliant
240                    (Some(lhs), _) if lhs == 0.0 => 0.0,
241
242                    //not IEEE-754 compliant
243                    (Some(lhs), _) if lhs == 1.0 => 1.0,
244
245                    //not IEEE-754 compliant
246                    (None, Some(rhs)) if rhs == 0.0 => 1.0,
247
248                    (None, Some(rhs)) if rhs == 1.0 => {
249                        fold.resolve_to_real_subexpressions(expr, lhs_id);
250                        return None;
251                    }
252
253                    (Some(lhs), Some(rhs)) => lhs.powf(rhs),
254
255                    (None, Some(rhs)) if rhs == 1.0 => {
256                        fold.resolve_to_real_subexpressions(expr, lhs_id);
257                        return None;
258                    }
259
260                    (_, _) => return None,
261                },
262
263                RealBinaryOperator::Modulus => {
264                    let lhs = lhs?;
265                    if lhs == 0.0 {
266                        0.0
267                    } else {
268                        lhs % rhs?
269                    }
270                }
271            }
272        }
273
274        RealExpression::Negate(_, val) => -fold.real_constant_fold(resolver, val)?,
275
276        RealExpression::Condition(condition, _, true_val_id, _, false_val_id) => {
277            let condition = fold.int_constant_fold(resolver, condition);
278            let true_val = fold.real_constant_fold(resolver, true_val_id);
279            let false_val = fold.real_constant_fold(resolver, false_val_id);
280
281            if condition? != 0 {
282                if let Some(true_val) = true_val {
283                    true_val
284                } else {
285                    fold.resolve_to_real_subexpressions(expr, true_val_id);
286                    return None;
287                }
288            } else if let Some(false_val) = false_val {
289                false_val
290            } else {
291                fold.resolve_to_real_subexpressions(expr, false_val_id);
292                return None;
293            }
294        }
295
296        RealExpression::BuiltInFunctionCall1p(call, arg) => {
297            let arg = fold.real_constant_fold(resolver, arg)?;
298            match call {
299                BuiltInFunctionCall1p::Ln => arg.ln(),
300                BuiltInFunctionCall1p::Sqrt => arg.sqrt(),
301                BuiltInFunctionCall1p::Exp(_) /* Whether this is a limexp or exp doesnt matter for constant eval*/ => arg.exp(),
302                BuiltInFunctionCall1p::Log => arg.log10(),
303                BuiltInFunctionCall1p::Abs => arg.abs(),
304                BuiltInFunctionCall1p::Floor => arg.floor(),
305                BuiltInFunctionCall1p::Ceil => arg.ceil(),
306                BuiltInFunctionCall1p::Sin => arg.sin(),
307                BuiltInFunctionCall1p::Cos => arg.cos(),
308                BuiltInFunctionCall1p::Tan => arg.tan(),
309                BuiltInFunctionCall1p::ArcSin => arg.asin(),
310                BuiltInFunctionCall1p::ArcCos => arg.acos(),
311                BuiltInFunctionCall1p::ArcTan => arg.atan(),
312                BuiltInFunctionCall1p::SinH => arg.sinh(),
313                BuiltInFunctionCall1p::CosH => arg.cosh(),
314                BuiltInFunctionCall1p::TanH => arg.tanh(),
315                BuiltInFunctionCall1p::ArcSinH => arg.asinh(),
316                BuiltInFunctionCall1p::ArcCosH => arg.acosh(),
317                BuiltInFunctionCall1p::ArcTanH => arg.atanh(),
318            }
319        }
320
321        RealExpression::BuiltInFunctionCall2p(call, arg1_id, arg2_id) => {
322            let arg1 = fold.real_constant_fold(resolver, arg1_id);
323            let arg2 = fold.real_constant_fold(resolver, arg2_id);
324            match call {
325                BuiltInFunctionCall2p::Pow => {
326                    match (arg1, arg2) {
327                        //not IEEE-754 compliant
328                        (Some(arg1), _) if arg1 == 0.0 => 0.0,
329
330                        //not IEEE-754 compliant
331                        (Some(arg1), _) if arg1 == 1.0 => 1.0,
332
333                        //not IEEE-754 compliant
334                        (None, Some(arg2)) if arg2 == 0.0 => 1.0,
335
336                        (None, Some(arg2)) if arg2 == 1.0 => {
337                            fold.resolve_to_real_subexpressions(expr, arg1_id);
338                            return None;
339                        }
340
341                        (Some(arg1), Some(arg2)) => arg1.powf(arg2),
342
343                        (_, _) => return None,
344                    }
345                }
346                BuiltInFunctionCall2p::Hypot => arg1?.hypot(arg2?),
347                BuiltInFunctionCall2p::Min => arg1?.min(arg2?),
348                BuiltInFunctionCall2p::Max => arg1?.max(arg2?),
349                BuiltInFunctionCall2p::ArcTan2 => arg1?.atan2(arg2?),
350            }
351        }
352
353        RealExpression::IntegerConversion(expr) => fold.int_constant_fold(resolver, expr)? as f64,
354
355        RealExpression::Vt(Some(_temp)) => {
356            //TODO abstract over constants
357            return None;
358        }
359
360        //Temperature/Sim parameters/Branches may be added in the future if there is any demand for it but it doesnt seem useful to me
361        RealExpression::Temperature
362        | RealExpression::SimParam(_, _)
363        | RealExpression::Vt(None)
364        | RealExpression::BranchAccess(_, _, _)
365        | RealExpression::Noise(_, _) => return None,
366    };
367    Some(res)
368}
369
370pub fn int_constant_fold(
371    fold: &mut impl ConstantFolder,
372    resolver: &mut impl ConstantResolver,
373    expr: IntegerExpressionId,
374) -> Option<i64> {
375    let res = match fold.mir()[expr].contents {
376        IntegerExpression::Literal(val) => return Some(val),
377        IntegerExpression::ParameterReference(param) => resolver.get_int_parameter_value(param)?,
378        IntegerExpression::VariableReference(var) => resolver.get_int_variable_value(var)?,
379
380        IntegerExpression::Abs(val) => fold.int_constant_fold(resolver, val)?.abs(),
381
382        IntegerExpression::Min(arg1, arg2) => {
383            let arg1 = fold.int_constant_fold(resolver, arg1);
384            let arg2 = fold.int_constant_fold(resolver, arg2);
385            arg1?.min(arg2?)
386        }
387
388        IntegerExpression::Max(arg1, arg2) => {
389            let arg1 = fold.int_constant_fold(resolver, arg1);
390            let arg2 = fold.int_constant_fold(resolver, arg2);
391            arg1?.max(arg2?)
392        }
393
394        IntegerExpression::BinaryOperator(lhs_id, op, rhs_id) => {
395            let lhs = fold.int_constant_fold(resolver, lhs_id);
396            let rhs = fold.int_constant_fold(resolver, rhs_id);
397            match op.contents {
398                IntegerBinaryOperator::Sum => match (lhs, rhs) {
399                    (Some(lhs), Some(rhs)) => lhs + rhs,
400                    (Some(0), None) => {
401                        fold.resolve_to_int_subexpressions(expr, rhs_id);
402                        return None;
403                    }
404                    (None, Some(0)) => {
405                        fold.resolve_to_int_subexpressions(expr, lhs_id);
406                        return None;
407                    }
408                    (_, _) => return None,
409                },
410                IntegerBinaryOperator::Subtract => match (lhs, rhs) {
411                    (Some(lhs), Some(rhs)) => lhs - rhs,
412
413                    (None, Some(0)) => {
414                        fold.resolve_to_int_subexpressions(expr, lhs_id);
415                        return None;
416                    }
417
418                    (_, _) => return None,
419                },
420                IntegerBinaryOperator::Multiply => match (lhs, rhs) {
421                    (Some(0), _) | (_, Some(0)) => 0,
422                    (Some(lhs), Some(rhs)) => lhs * rhs,
423                    (Some(1), None) => {
424                        fold.resolve_to_int_subexpressions(expr, rhs_id);
425                        return None;
426                    }
427                    (None, Some(1)) => {
428                        fold.resolve_to_int_subexpressions(expr, lhs_id);
429                        return None;
430                    }
431                    (_, _) => return None,
432                },
433                IntegerBinaryOperator::Divide => match (lhs, rhs) {
434                    (Some(0), _) => 0,
435                    (Some(lhs), Some(rhs)) => lhs / rhs,
436                    (None, Some(1)) => {
437                        fold.resolve_to_int_subexpressions(expr, lhs_id);
438                        return None;
439                    }
440                    (_, _) => return None,
441                },
442                IntegerBinaryOperator::Exponent => match (lhs, rhs) {
443                    (Some(0), _) => 0,
444                    (Some(1), _) => 1,
445
446                    (None, Some(0)) => 1,
447                    (None, Some(1)) => {
448                        fold.resolve_to_int_subexpressions(expr, lhs_id);
449                        return None;
450                    }
451
452                    // TODO Proper error on overflow
453                    (Some(lhs), Some(rhs)) if rhs >= 0 => lhs.pow(rhs as u32),
454
455                    (Some(_), Some(_)) => 0,
456
457                    (_, _) => return None,
458                },
459
460                IntegerBinaryOperator::Modulus => {
461                    let lhs = lhs?;
462                    if lhs == 0 {
463                        0
464                    } else {
465                        lhs % rhs?
466                    }
467                }
468                IntegerBinaryOperator::ShiftLeft => {
469                    if lhs == Some(0) {
470                        0
471                    } else if rhs == Some(0) && lhs == None {
472                        fold.resolve_to_int_subexpressions(expr, lhs_id);
473                        return None;
474                    } else {
475                        lhs? << rhs?
476                    }
477                }
478                IntegerBinaryOperator::ShiftRight => {
479                    if lhs == Some(0) {
480                        0
481                    } else if rhs == Some(0) && lhs == None {
482                        fold.resolve_to_int_subexpressions(expr, lhs_id);
483                        return None;
484                    } else {
485                        lhs? >> rhs?
486                    }
487                }
488                IntegerBinaryOperator::Xor => lhs? ^ rhs?,
489                IntegerBinaryOperator::NXor => !(lhs? ^ rhs?),
490                IntegerBinaryOperator::And => {
491                    if lhs == Some(0) || rhs == Some(0) {
492                        0
493                    } else {
494                        lhs? & rhs?
495                    }
496                }
497                IntegerBinaryOperator::Or => {
498                    if lhs == Some(0) && rhs.is_none() {
499                        fold.resolve_to_int_subexpressions(expr, rhs_id);
500                        return None;
501                    } else if rhs == Some(0) && lhs.is_none() {
502                        fold.resolve_to_int_subexpressions(expr, lhs_id);
503                        return None;
504                    } else {
505                        lhs? | rhs?
506                    }
507                }
508                IntegerBinaryOperator::LogicOr => match (lhs, rhs) {
509                    (Some(0), Some(0)) => 0,
510                    (Some(_), Some(_)) => 1,
511                    (None, None) => return None,
512                    (Some(0), None) => {
513                        fold.resolve_to_int_subexpressions(expr, rhs_id);
514                        return None;
515                    }
516                    (None, Some(0)) => {
517                        fold.resolve_to_int_subexpressions(expr, lhs_id);
518                        return None;
519                    }
520                    (Some(_), None) | (None, Some(_)) => 1,
521                },
522                IntegerBinaryOperator::LogicAnd => match (lhs, rhs) {
523                    (Some(0), Some(0)) => 0,
524                    (Some(0), _) | (_, Some(0)) => 0,
525                    (Some(_), Some(_)) => 1,
526                    (None, None) => return None,
527                    (Some(_), None) => {
528                        fold.resolve_to_int_subexpressions(expr, rhs_id);
529                        return None;
530                    }
531                    (None, Some(_)) => {
532                        fold.resolve_to_int_subexpressions(expr, lhs_id);
533                        return None;
534                    }
535                },
536            }
537        }
538
539        IntegerExpression::UnaryOperator(op, arg) => {
540            let arg = fold.int_constant_fold(resolver, arg)?;
541            match op.contents {
542                UnaryOperator::BitNegate => !arg,
543                UnaryOperator::LogicNegate => (arg == 0) as i64,
544                UnaryOperator::ArithmeticNegate => -arg,
545                UnaryOperator::ExplicitPositive => arg,
546            }
547        }
548
549        IntegerExpression::IntegerComparison(lhs, op, rhs) => {
550            let lhs = fold.int_constant_fold(resolver, lhs);
551            let rhs = fold.int_constant_fold(resolver, rhs);
552            let (lhs, rhs) = (lhs?, rhs?);
553            let res = match op.contents {
554                ComparisonOperator::LessThen => lhs < rhs,
555                ComparisonOperator::LessEqual => lhs <= rhs,
556                ComparisonOperator::GreaterThen => lhs > rhs,
557                ComparisonOperator::GreaterEqual => lhs >= rhs,
558                ComparisonOperator::LogicEqual => lhs == rhs,
559                ComparisonOperator::LogicalNotEqual => lhs != rhs,
560            };
561            res as i64
562        }
563
564        IntegerExpression::RealComparison(lhs, op, rhs) => {
565            let lhs = fold.real_constant_fold(resolver, lhs);
566            let rhs = fold.real_constant_fold(resolver, rhs);
567            let (lhs, rhs) = (lhs?, rhs?);
568            let res = match op.contents {
569                ComparisonOperator::LessThen => lhs < rhs,
570                ComparisonOperator::LessEqual => lhs <= rhs,
571                ComparisonOperator::GreaterThen => lhs > rhs,
572                ComparisonOperator::GreaterEqual => lhs >= rhs,
573                ComparisonOperator::LogicEqual => lhs == rhs,
574                ComparisonOperator::LogicalNotEqual => lhs != rhs,
575            };
576            res as i64
577        }
578
579        IntegerExpression::StringEq(lhs, rhs) => {
580            let lhs = fold.string_constant_fold(resolver, lhs);
581            let rhs = fold.string_constant_fold(resolver, rhs);
582            (lhs? == rhs?) as i64
583        }
584
585        IntegerExpression::StringNEq(lhs, rhs) => {
586            let lhs = fold.string_constant_fold(resolver, lhs);
587            let rhs = fold.string_constant_fold(resolver, rhs);
588            (lhs? != rhs?) as i64
589        }
590
591        IntegerExpression::Condition(condition, _, true_val_id, _, false_val_id) => {
592            let condition = fold.int_constant_fold(resolver, condition);
593            let true_val = fold.int_constant_fold(resolver, true_val_id);
594            let false_val = fold.int_constant_fold(resolver, false_val_id);
595
596            if condition? != 0 {
597                if let Some(true_val) = true_val {
598                    true_val
599                } else {
600                    fold.resolve_to_int_subexpressions(expr, true_val_id);
601                    return None;
602                }
603            } else if let Some(false_val) = false_val {
604                false_val
605            } else {
606                fold.resolve_to_int_subexpressions(expr, false_val_id);
607                return None;
608            }
609        }
610
611        IntegerExpression::RealCast(val) => fold.real_constant_fold(resolver, val)?.round() as i64,
612
613        //TODO system function call constant fold
614        IntegerExpression::PortConnected(_)
615        | IntegerExpression::ParamGiven(_)
616        | IntegerExpression::NetReference(_)
617        | IntegerExpression::PortReference(_)
618        | IntegerExpression::FunctionCall(_, _) => return None,
619    };
620
621    Some(res)
622}
623
624pub fn string_constant_fold(
625    fold: &mut impl ConstantFolder,
626    resolver: &mut impl ConstantResolver,
627    expr: StringExpressionId,
628) -> Option<StringLiteral> {
629    Some(match fold.mir()[expr].contents {
630        StringExpression::Literal(val) => return Some(val),
631        StringExpression::VariableReference(var) => resolver.get_str_variable_value(var)?,
632        StringExpression::ParameterReference(param) => resolver.get_str_parameter_value(param)?,
633
634        StringExpression::Condition(condition, _, true_val, _, false_val) => {
635            let condition = fold.int_constant_fold(resolver, condition);
636            let true_val = fold.string_constant_fold(resolver, true_val);
637            let false_val = fold.string_constant_fold(resolver, false_val);
638
639            if condition? == 0 {
640                false_val?
641            } else {
642                true_val?
643            }
644        }
645        // TODO system function call constant fold
646        StringExpression::SimParam(_) => return None,
647    })
648}
649
650pub struct ReadingConstantFold<'lt>(pub &'lt Mir);
651
652impl<'lt> ConstantFolder for ReadingConstantFold<'lt> {
653    fn resolve_to_string_subexpressions(
654        &mut self,
655        _dst: StringExpressionId,
656        _newval: StringExpressionId,
657    ) {
658    }
659
660    fn resolve_to_int_subexpressions(
661        &mut self,
662        _dst: IntegerExpressionId,
663        _newval: IntegerExpressionId,
664    ) {
665    }
666
667    fn resolve_to_real_subexpressions(
668        &mut self,
669        _dst: RealExpressionId,
670        _newval: RealExpressionId,
671    ) {
672    }
673
674    fn mir(&self) -> &Mir {
675        self.0
676    }
677}
678
679pub struct IntermediateWritingConstantFold<'lt>(pub &'lt mut Mir);
680
681impl<'lt> ConstantFolder for IntermediateWritingConstantFold<'lt> {
682    fn real_constant_fold(
683        &mut self,
684        resolver: &mut impl ConstantResolver,
685        expr: RealExpressionId,
686    ) -> Option<f64> {
687        let res = real_constant_fold(self, resolver, expr)?;
688        self.0[expr].contents = RealExpression::Literal(res);
689        Some(res)
690    }
691
692    fn int_constant_fold(
693        &mut self,
694        resolver: &mut impl ConstantResolver,
695        expr: IntegerExpressionId,
696    ) -> Option<i64> {
697        let res = int_constant_fold(self, resolver, expr)?;
698        self.0[expr].contents = IntegerExpression::Literal(res);
699        Some(res)
700    }
701
702    fn string_constant_fold(
703        &mut self,
704        resolver: &mut impl ConstantResolver,
705        expr: StringExpressionId,
706    ) -> Option<StringLiteral> {
707        let res = string_constant_fold(self, resolver, expr)?;
708        self.0[expr].contents = StringExpression::Literal(res);
709        Some(res)
710    }
711
712    fn resolve_to_string_subexpressions(
713        &mut self,
714        dst: StringExpressionId,
715        newval: StringExpressionId,
716    ) {
717        self.0[dst] = self.0[newval].clone();
718    }
719
720    fn resolve_to_int_subexpressions(
721        &mut self,
722        dst: IntegerExpressionId,
723        newval: IntegerExpressionId,
724    ) {
725        self.0[dst] = self.0[newval].clone();
726    }
727
728    fn resolve_to_real_subexpressions(&mut self, dst: RealExpressionId, newval: RealExpressionId) {
729        self.0[dst] = self.0[newval].clone();
730    }
731
732    fn mir(&self) -> &Mir {
733        &self.0
734    }
735}
736
737pub trait ConstantFolder: Sized {
738    // We do defaults with external functions because specialization is nowhere near stable
739    // This is similar to overwriting methods in OOP
740
741    fn real_constant_fold(
742        &mut self,
743        resolver: &mut impl ConstantResolver,
744        expr: RealExpressionId,
745    ) -> Option<f64> {
746        real_constant_fold(self, resolver, expr)
747    }
748
749    fn int_constant_fold(
750        &mut self,
751        resolver: &mut impl ConstantResolver,
752        expr: IntegerExpressionId,
753    ) -> Option<i64> {
754        int_constant_fold(self, resolver, expr)
755    }
756
757    fn string_constant_fold(
758        &mut self,
759        resolver: &mut impl ConstantResolver,
760        expr: StringExpressionId,
761    ) -> Option<StringLiteral> {
762        string_constant_fold(self, resolver, expr)
763    }
764
765    fn resolve_to_string_subexpressions(
766        &mut self,
767        dst: StringExpressionId,
768        newval: StringExpressionId,
769    );
770    fn resolve_to_int_subexpressions(
771        &mut self,
772        dst: IntegerExpressionId,
773        newval: IntegerExpressionId,
774    );
775    fn resolve_to_real_subexpressions(&mut self, dst: RealExpressionId, newval: RealExpressionId);
776
777    fn mir(&self) -> &Mir;
778}
779
780impl ControlFlowGraph {
781    pub fn constant_fold(
782        &mut self,
783        fold: &mut impl ConstantFolder,
784        udg: &mut UseDefGraph,
785        known_values: &mut ConstantFoldState,
786        write_intermediate: bool,
787    ) {
788        let mut temporary_set = BitSet::new_empty(udg.len_stmd_idx());
789
790        for (id, bb) in self.reverse_postorder_itermut() {
791            Self::constant_fold_basic_block(
792                id,
793                bb,
794                fold,
795                udg,
796                known_values,
797                &mut temporary_set,
798                write_intermediate,
799            );
800
801            if let Terminator::Split {
802                condition,
803                true_block,
804                false_block,
805                merge,
806            } = bb.terminator
807            {
808                temporary_set
809                    .as_mut_slice()
810                    .copy_from_slice(udg.terminator_use_def_chains[id].as_slice());
811
812                let folded_condition = fold.int_constant_fold(
813                    &mut ConstantPropagator {
814                        known_values,
815                        dependencys_before: &udg.terminator_use_def_chains[id],
816                        dependencys_after: &mut temporary_set,
817                        variables_assignments: &udg.assignments,
818                    },
819                    condition,
820                );
821
822                if write_intermediate || folded_condition.is_some() {
823                    std::mem::swap(&mut temporary_set, &mut udg.terminator_use_def_chains[id]);
824                }
825
826                match folded_condition {
827                    Some(0) => {
828                        debug!(
829                            "{:?}->(false: {:?}, true: {:?}) always false (condition: {:?})",
830                            id, false_block, true_block, condition
831                        );
832
833                        bb.terminator = Terminator::Goto(false_block);
834                    }
835
836                    Some(_) => {
837                        if merge == id {
838                            panic!("Found constant infinite loop!")
839                        }
840
841                        bb.terminator = Terminator::Goto(true_block);
842                    }
843
844                    None => (),
845                }
846            }
847        }
848    }
849
850    fn constant_fold_basic_block(
851        id: BasicBlockId,
852        bb: &mut BasicBlock,
853        fold: &mut impl ConstantFolder,
854        udg: &mut UseDefGraph,
855        known_values: &mut ConstantFoldState,
856        temporary_set: &mut BitSet<StatementId>,
857        write_intermediate: bool,
858    ) {
859        for &stmt in bb.statements.iter().rev() {
860            temporary_set
861                .as_mut_slice()
862                .copy_from_slice(udg.stmt_use_def_chains[stmt].as_slice());
863
864            let mut resolver = ConstantPropagator {
865                known_values,
866                dependencys_before: &udg.stmt_use_def_chains[stmt],
867                dependencys_after: temporary_set,
868                variables_assignments: &udg.assignments,
869            };
870
871            match fold.mir()[stmt] {
872                Statement::Assignment(_, _, val) => {
873                    let success = match val {
874                        ExpressionId::Real(val) => {
875                            if let Some(val) = fold.real_constant_fold(&mut resolver, val) {
876                                let old = known_values.real_definitions.insert(stmt, val);
877                                #[cfg(debug_assertions)]
878                                match old{
879                                    Some(new) if new != val => panic!(
880                                        "Statement {} in block {} was assigned twice with different values (old={},new={})!",
881                                        stmt,
882                                        id,
883                                        old.unwrap(),
884                                        val
885                                    ),
886                                    _ => ()
887                                }
888                                true
889                            } else {
890                                false
891                            }
892                        }
893
894                        ExpressionId::Integer(val) => {
895                            if let Some(val) = fold.int_constant_fold(&mut resolver, val) {
896                                let old = known_values.integer_definitions.insert(stmt, val);
897                                #[cfg(debug_assertions)]
898                                match old{
899                                    Some(new) if new != val => panic!(
900                                        "Statement {} in block {} was assigned twice with different values (old={},new={})!",
901                                        stmt,
902                                        id,
903                                        old.unwrap(),
904                                        val
905                                    ),
906                                    _ => ()
907                                }
908                                true
909                            } else {
910                                false
911                            }
912                        }
913                        ExpressionId::String(val) => {
914                            if let Some(val) = fold.string_constant_fold(&mut resolver, val) {
915                                let old = known_values.string_definitions.insert(stmt, val);
916                                #[cfg(debug_assertions)]
917                                match old{
918                                    Some(old) if old != val => panic!(
919                                        "Statement {} in block {} was assigned twice with different values (old={},new={})!",
920                                        stmt,
921                                        id,
922                                        old,
923                                        val
924                                    ),
925                                    _ => ()
926                                }
927                                true
928                            } else {
929                                false
930                            }
931                        }
932                    };
933
934                    if success || write_intermediate {
935                        std::mem::swap(temporary_set, &mut udg.stmt_use_def_chains[stmt])
936                    }
937                }
938
939                Statement::Contribute(_, _, _, val) if write_intermediate => {
940                    fold.real_constant_fold(&mut resolver, val);
941                    std::mem::swap(temporary_set, &mut udg.stmt_use_def_chains[stmt])
942                }
943
944                _ => (),
945            }
946        }
947    }
948}