Skip to main content

litex/infer/
infer_equal_and_normal.rs

1use crate::prelude::*;
2use crate::verify::{compare_normalized_number_str_to_zero, NumberCompareResult};
3
4fn obj_is_infer_literal_zero(obj: &Obj) -> bool {
5    match obj {
6        Obj::Number(n) => matches!(
7            compare_normalized_number_str_to_zero(&n.normalized_value),
8            NumberCompareResult::Equal
9        ),
10        _ => false,
11    }
12}
13
14impl Runtime {
15    fn store_inferred_fact_and_record_result(
16        &mut self,
17        inferred_fact: Fact,
18        equal_fact: &EqualFact,
19        infer_result: &mut InferResult,
20        infer_step_description: &str,
21    ) -> Result<(), RuntimeError> {
22        infer_result.new_fact(&inferred_fact);
23        self.verify_well_defined_and_store_and_infer_with_default_verify_state(inferred_fact)
24            .map_err(|previous_error| {
25                RuntimeError::from(InferRuntimeError(RuntimeErrorStruct::new(
26                    None,
27                    format!(
28                        "failed to store inferred {} while inferring `{}`",
29                        infer_step_description, equal_fact
30                    ),
31                    equal_fact.line_file.clone(),
32                    Some(previous_error),
33                    vec![],
34                )))
35            })?;
36        Ok(())
37    }
38
39    fn infer_equal_fact_cart_from_known_side(
40        &mut self,
41        known_cart_obj: &Cart,
42        known_cart_obj_as_symbol: &Obj,
43        target_obj: &Obj,
44        equal_fact: &EqualFact,
45        infer_result: &mut InferResult,
46    ) -> Result<(), RuntimeError> {
47        let target_is_cart_fact =
48            IsCartFact::new(target_obj.clone(), equal_fact.line_file.clone()).into();
49        self.store_inferred_fact_and_record_result(
50            target_is_cart_fact,
51            equal_fact,
52            infer_result,
53            "cart fact",
54        )?;
55
56        let target_cart_dim_obj = CartDim::new(target_obj.clone()).into();
57        let known_cart_dim_obj = Number::new(known_cart_obj.args.len().to_string()).into();
58        let cart_dim_equal_fact = EqualFact::new(
59            target_cart_dim_obj,
60            known_cart_dim_obj,
61            equal_fact.line_file.clone(),
62        )
63        .into();
64        self.store_inferred_fact_and_record_result(
65            cart_dim_equal_fact,
66            equal_fact,
67            infer_result,
68            "cart_dim fact",
69        )?;
70        self.store_known_cart_obj(
71            &known_cart_obj_as_symbol.to_string(),
72            known_cart_obj.clone(),
73            equal_fact.line_file.clone(),
74        );
75        self.store_known_cart_obj(
76            &target_obj.to_string(),
77            known_cart_obj.clone(),
78            equal_fact.line_file.clone(),
79        );
80        Ok(())
81    }
82
83    fn infer_equal_fact_tuple_from_known_side(
84        &mut self,
85        known_tuple_obj: &Tuple,
86        target_obj: &Obj,
87        equal_fact: &EqualFact,
88        infer_result: &mut InferResult,
89    ) -> Result<(), RuntimeError> {
90        if known_tuple_obj.args.len() < 2 {
91            return Ok(());
92        }
93        let target_is_tuple_fact =
94            IsTupleFact::new(target_obj.clone(), equal_fact.line_file.clone()).into();
95        self.store_inferred_fact_and_record_result(
96            target_is_tuple_fact,
97            equal_fact,
98            infer_result,
99            "tuple fact",
100        )?;
101
102        let target_tuple_dim_obj = TupleDim::new(target_obj.clone()).into();
103        let known_tuple_dim_obj = Number::new(known_tuple_obj.args.len().to_string()).into();
104        let tuple_dim_equal_fact = EqualFact::new(
105            target_tuple_dim_obj,
106            known_tuple_dim_obj,
107            equal_fact.line_file.clone(),
108        )
109        .into();
110        self.store_inferred_fact_and_record_result(
111            tuple_dim_equal_fact,
112            equal_fact,
113            infer_result,
114            "tuple_dim fact",
115        )?;
116
117        self.store_tuple_obj_and_cart(
118            &target_obj.to_string(),
119            Some(known_tuple_obj.clone()),
120            None,
121            equal_fact.line_file.clone(),
122        );
123        Ok(())
124    }
125
126    fn infer_equal_fact_finite_seq_list_from_known_side(
127        &mut self,
128        known_list: &FiniteSeqListObj,
129        target_obj: &Obj,
130        equal_fact: &EqualFact,
131    ) -> Result<(), RuntimeError> {
132        let lf = equal_fact.line_file.clone();
133        self.store_known_finite_seq_list_obj(&target_obj.to_string(), known_list.clone(), None, lf);
134        Ok(())
135    }
136
137    fn infer_equal_fact_matrix_list_from_known_side(
138        &mut self,
139        known_matrix: &MatrixListObj,
140        target_obj: &Obj,
141        equal_fact: &EqualFact,
142    ) -> Result<(), RuntimeError> {
143        let lf = equal_fact.line_file.clone();
144        self.store_known_matrix_list_obj(&target_obj.to_string(), known_matrix.clone(), None, lf);
145        Ok(())
146    }
147
148    fn infer_equal_fact_by_finite_seq_list(
149        &mut self,
150        equal_fact: &EqualFact,
151    ) -> Result<InferResult, RuntimeError> {
152        let infer_result = InferResult::new();
153
154        if let Obj::FiniteSeqListObj(list) = &equal_fact.left {
155            if !matches!(&equal_fact.right, Obj::FiniteSeqListObj(_)) {
156                self.infer_equal_fact_finite_seq_list_from_known_side(
157                    list,
158                    &equal_fact.right,
159                    equal_fact,
160                )?;
161            }
162        }
163
164        if let Obj::FiniteSeqListObj(list) = &equal_fact.right {
165            if !matches!(&equal_fact.left, Obj::FiniteSeqListObj(_)) {
166                self.infer_equal_fact_finite_seq_list_from_known_side(
167                    list,
168                    &equal_fact.left,
169                    equal_fact,
170                )?;
171            }
172        }
173
174        Ok(infer_result)
175    }
176
177    fn infer_equal_fact_by_matrix_list(
178        &mut self,
179        equal_fact: &EqualFact,
180    ) -> Result<InferResult, RuntimeError> {
181        let infer_result = InferResult::new();
182
183        if let Obj::MatrixListObj(m) = &equal_fact.left {
184            if !matches!(&equal_fact.right, Obj::MatrixListObj(_)) {
185                self.infer_equal_fact_matrix_list_from_known_side(
186                    m,
187                    &equal_fact.right,
188                    equal_fact,
189                )?;
190            }
191        }
192
193        if let Obj::MatrixListObj(m) = &equal_fact.right {
194            if !matches!(&equal_fact.left, Obj::MatrixListObj(_)) {
195                self.infer_equal_fact_matrix_list_from_known_side(m, &equal_fact.left, equal_fact)?;
196            }
197        }
198
199        Ok(infer_result)
200    }
201
202    // From `u = v`: merge numeric normal forms in the env; if one side is `a-b` and the other `0`, emit `a=b`;
203    // if one side is a literal cart/tuple/finite-seq/matrix list, record shape for the other symbol.
204    // Example: `a = 1+2` binds `a` to normalized `3`; `0 = x-y` yields fact `x = y`.
205    pub fn infer_equal_fact(
206        &mut self,
207        equal_fact: &EqualFact,
208    ) -> Result<InferResult, RuntimeError> {
209        let mut infer_result = InferResult::new();
210        infer_result.new_infer_result_inside(
211            self.infer_equal_fact_from_subtraction_equals_zero(equal_fact)?,
212        );
213        infer_result
214            .new_infer_result_inside(self.infer_equal_fact_and_give_value_to_obj(equal_fact)?);
215        infer_result.new_infer_result_inside(self.infer_equal_fact_by_cart(equal_fact)?);
216        infer_result.new_infer_result_inside(self.infer_equal_fact_by_tuple(equal_fact)?);
217        infer_result.new_infer_result_inside(self.infer_equal_fact_by_finite_seq_list(equal_fact)?);
218        infer_result.new_infer_result_inside(self.infer_equal_fact_by_matrix_list(equal_fact)?);
219        infer_result.new_infer_result_inside(self.infer_equal_fact_by_anonymous_fn(equal_fact)?);
220
221        Ok(infer_result)
222    }
223
224    /// `name = '(... ) ... { ... }'`: treat `name` as having the anonymous function's `FnSetBody`
225    /// (same side table as `name $in fn ...` after infer).
226    fn infer_equal_fact_by_anonymous_fn(
227        &mut self,
228        equal_fact: &EqualFact,
229    ) -> Result<InferResult, RuntimeError> {
230        if let Obj::AnonymousFn(anon) = &equal_fact.right {
231            if !matches!(&equal_fact.left, Obj::AnonymousFn(_)) {
232                let eq = (*anon.equal_to).clone();
233                let lf = equal_fact.line_file.clone();
234                self.register_known_objs_in_fn_sets_for_element_body(
235                    &equal_fact.left,
236                    anon.body.clone(),
237                    Some(eq),
238                    lf.clone(),
239                    lf,
240                );
241            }
242        }
243        if let Obj::AnonymousFn(anon) = &equal_fact.left {
244            if !matches!(&equal_fact.right, Obj::AnonymousFn(_)) {
245                let eq = (*anon.equal_to).clone();
246                let lf = equal_fact.line_file.clone();
247                self.register_known_objs_in_fn_sets_for_element_body(
248                    &equal_fact.right,
249                    anon.body.clone(),
250                    Some(eq),
251                    lf.clone(),
252                    lf,
253                );
254            }
255        }
256        Ok(InferResult::new())
257    }
258
259    // `0 = u - v` or `u - v = 0` => add `u = v` (non-trivial pair only).
260    fn infer_equal_fact_from_subtraction_equals_zero(
261        &mut self,
262        equal_fact: &EqualFact,
263    ) -> Result<InferResult, RuntimeError> {
264        let mut infer_result = InferResult::new();
265        let (a, b) = if obj_is_infer_literal_zero(&equal_fact.left) {
266            match &equal_fact.right {
267                Obj::Sub(s) => (s.left.as_ref().clone(), s.right.as_ref().clone()),
268                _ => return Ok(infer_result),
269            }
270        } else if obj_is_infer_literal_zero(&equal_fact.right) {
271            match &equal_fact.left {
272                Obj::Sub(s) => (s.left.as_ref().clone(), s.right.as_ref().clone()),
273                _ => return Ok(infer_result),
274            }
275        } else {
276            return Ok(infer_result);
277        };
278        if a.to_string() == b.to_string() {
279            return Ok(infer_result);
280        }
281        let derived: Fact = EqualFact::new(a, b, equal_fact.line_file.clone()).into();
282        self.store_inferred_fact_and_record_result(
283            derived,
284            equal_fact,
285            &mut infer_result,
286            "equality from a - b = 0",
287        )?;
288        Ok(infer_result)
289    }
290
291    fn infer_equal_fact_by_cart(
292        &mut self,
293        equal_fact: &EqualFact,
294    ) -> Result<InferResult, RuntimeError> {
295        let mut infer_result = InferResult::new();
296
297        if let Obj::Cart(cart) = &equal_fact.left {
298            self.infer_equal_fact_cart_from_known_side(
299                cart,
300                &equal_fact.left,
301                &equal_fact.right,
302                equal_fact,
303                &mut infer_result,
304            )?;
305        }
306
307        if let Obj::Cart(cart) = &equal_fact.right {
308            self.infer_equal_fact_cart_from_known_side(
309                cart,
310                &equal_fact.right,
311                &equal_fact.left,
312                equal_fact,
313                &mut infer_result,
314            )?;
315        }
316
317        Ok(infer_result)
318    }
319
320    fn infer_equal_fact_by_tuple(
321        &mut self,
322        equal_fact: &EqualFact,
323    ) -> Result<InferResult, RuntimeError> {
324        let mut infer_result = InferResult::new();
325
326        if let Obj::Tuple(tuple) = &equal_fact.left {
327            self.infer_equal_fact_tuple_from_known_side(
328                tuple,
329                &equal_fact.right,
330                equal_fact,
331                &mut infer_result,
332            )?;
333        }
334
335        if let Obj::Tuple(tuple) = &equal_fact.right {
336            self.infer_equal_fact_tuple_from_known_side(
337                tuple,
338                &equal_fact.left,
339                equal_fact,
340                &mut infer_result,
341            )?;
342        }
343
344        Ok(infer_result)
345    }
346
347    fn infer_equal_fact_and_give_value_to_obj(
348        &mut self,
349        equal_fact: &EqualFact,
350    ) -> Result<InferResult, RuntimeError> {
351        if let Some(right_calculated_value) = self.resolve_obj_to_number(&equal_fact.right) {
352            self.top_level_env()
353                .known_objs_equal_to_normalized_decimal_number
354                .insert(equal_fact.left.to_string(), right_calculated_value);
355        }
356
357        if let Some(left_calculated_value) = self.resolve_obj_to_number(&equal_fact.left) {
358            self.top_level_env()
359                .known_objs_equal_to_normalized_decimal_number
360                .insert(equal_fact.right.to_string(), left_calculated_value);
361        }
362
363        if let Some(derived) =
364            crate::environment::equality_linear_derive::maybe_derived_linear_equal_fact(equal_fact)
365        {
366            if let Some(n) = self.resolve_obj_to_number(&derived.right) {
367                self.top_level_env()
368                    .known_objs_equal_to_normalized_decimal_number
369                    .insert(derived.left.to_string(), n);
370            }
371        }
372
373        Ok(InferResult::new())
374    }
375
376    // Predicate `P(args)`: check args against `P`'s param types, then store each instantiated `iff` body.
377    // Example: if `P` is defined by `iff` clauses, those clauses become facts with `args` substituted.
378    pub fn infer_normal_atomic_fact(
379        &mut self,
380        normal_atomic_fact: &NormalAtomicFact,
381    ) -> Result<InferResult, RuntimeError> {
382        let predicate_name = normal_atomic_fact.predicate.to_string();
383        let predicate_definition = match self.get_prop_definition_by_name(&predicate_name) {
384            Some(predicate_definition) => predicate_definition.clone(),
385            None => return Ok(InferResult::new()),
386        };
387        let mut infer_result = InferResult::new();
388
389        let param_type_infer = self
390            .store_args_satisfy_param_type_when_not_defining_new_identifiers(
391                &predicate_definition.params_def_with_type,
392                &normal_atomic_fact.body,
393                normal_atomic_fact.line_file.clone(),
394                ParamObjType::DefHeader,
395            )
396            .map_err(|previous_error| {
397                RuntimeError::from(InferRuntimeError(RuntimeErrorStruct::new(
398                    None,
399                    format!(
400                        "failed to verify parameter types for `{}`",
401                        normal_atomic_fact
402                    ),
403                    normal_atomic_fact.line_file.clone(),
404                    Some(previous_error),
405                    vec![],
406                )))
407            })?;
408        infer_result.new_infer_result_inside(param_type_infer);
409
410        let param_to_arg_map = self.params_to_arg_map(
411            &predicate_definition.params_def_with_type,
412            &normal_atomic_fact.body,
413        )?;
414
415        for iff_fact in predicate_definition.iff_facts.iter() {
416            let instantiated_iff_fact = self
417                .inst_fact(
418                    iff_fact,
419                    &param_to_arg_map,
420                    ParamObjType::DefHeader,
421                    Some(normal_atomic_fact.line_file.clone()),
422                )
423                .map_err(|e| {
424                    RuntimeError::from(InferRuntimeError(RuntimeErrorStruct::new(
425                        None,
426                        format!(
427                            "failed to instantiate iff fact while inferring `{}`",
428                            normal_atomic_fact
429                        ),
430                        normal_atomic_fact.line_file.clone(),
431                        Some(e),
432                        vec![],
433                    )))
434                })?;
435            let fact_to_store = instantiated_iff_fact;
436            infer_result.new_fact(&fact_to_store);
437            self.verify_well_defined_and_store_and_infer_with_default_verify_state(fact_to_store)
438                .map_err(|previous_error| {
439                    RuntimeError::from(InferRuntimeError(RuntimeErrorStruct::new(
440                        None,
441                        format!(
442                            "failed to store instantiated iff fact while inferring `{}`",
443                            normal_atomic_fact
444                        ),
445                        normal_atomic_fact.line_file.clone(),
446                        Some(previous_error),
447                        vec![],
448                    )))
449                })?;
450        }
451
452        Ok(infer_result)
453    }
454}