Skip to main content

tract_data/dim/
sym.rs

1use itertools::Itertools;
2use parking_lot::ReentrantMutex;
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::fmt::{self, Display};
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::sync::{Arc, Weak};
8use string_interner::DefaultStringInterner;
9use string_interner::Symbol as _;
10
11use crate::TractResult;
12
13use super::parse::parse_assertion;
14use super::{Assertion, TDim, parse_tdim};
15
/// Process-wide counter from which every new `SymbolScopeData` draws its `id`.
static SCOPE_COUNTER: AtomicUsize = AtomicUsize::new(0);
17
18#[derive(Clone, Default)]
19pub struct SymbolScope(pub Arc<ReentrantMutex<RefCell<SymbolScopeData>>>);
20
21impl PartialEq for SymbolScope {
22    fn eq(&self, other: &Self) -> bool {
23        Arc::ptr_eq(&self.0, &other.0)
24    }
25}
26
27impl Eq for SymbolScope {}
28
29pub struct SymbolScopeData {
30    id: usize,
31    table: DefaultStringInterner,
32    assertions: Vec<Assertion>,
33    scenarios: Vec<(String, Vec<Assertion>)>,
34}
35
36impl Default for SymbolScopeData {
37    fn default() -> Self {
38        SymbolScopeData {
39            id: SCOPE_COUNTER.fetch_add(1, Ordering::Relaxed),
40            table: DefaultStringInterner::default(),
41            assertions: Vec::new(),
42            scenarios: Vec::new(),
43        }
44    }
45}
46
47impl SymbolScope {
48    pub fn id(&self) -> usize {
49        let locked = self.0.lock();
50        let locked = locked.borrow();
51        locked.id
52    }
53
54    pub fn proof_cache_session(&self) -> ProofCacheSession {
55        ProofCacheSession::new(self.id())
56    }
57
58    pub fn get(&self, name: &str) -> Option<Symbol> {
59        let locked = self.0.lock();
60        let locked = locked.borrow();
61        locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym))
62    }
63
64    /// Get or create the coordinate symbol for axis `k` (named "🎯{k}").
65    pub fn coord_sym(&self, k: usize) -> Symbol {
66        self.sym(&format!("🎯{k}"))
67    }
68
69    pub fn sym(&self, name: &str) -> Symbol {
70        let locked = self.0.lock();
71        let mut locked = locked.borrow_mut();
72        let sym = locked.table.get_or_intern(name);
73        Symbol(Arc::downgrade(&self.0), sym)
74    }
75
76    pub fn new_with_prefix(&self, prefix: &str) -> Symbol {
77        let locked = self.0.lock();
78        let mut locked = locked.borrow_mut();
79        let sym = if locked.table.get(prefix).is_none() {
80            locked.table.get_or_intern(prefix)
81        } else {
82            let mut i = 0;
83            loop {
84                let s = format!("{prefix}_{i}");
85                if locked.table.get(&s).is_none() {
86                    break locked.table.get_or_intern(s);
87                }
88                i += 1;
89            }
90        };
91        Symbol(Arc::downgrade(&self.0), sym)
92    }
93
94    pub fn parse_tdim(&self, input: impl AsRef<str>) -> TractResult<TDim> {
95        parse_tdim(self, input.as_ref())
96    }
97
98    pub fn add_assertion(&self, assert: impl Into<String>) -> TractResult<()> {
99        let assert = assert.into();
100        let assert = parse_assertion(self, &assert)?;
101        let locked = self.0.lock();
102        let mut locked = locked.borrow_mut();
103        locked.assertions.push(assert);
104        Ok(())
105    }
106
107    pub fn with_assertion(self, assert: impl Into<String>) -> TractResult<Self> {
108        self.add_assertion(assert)?;
109        Ok(self)
110    }
111
112    pub fn all_assertions(&self) -> Vec<Assertion> {
113        let locked = self.0.lock();
114        let locked = locked.borrow();
115        locked.assertions.clone()
116    }
117
118    pub fn all_scenarios(&self) -> impl IntoIterator<Item = (String, Vec<Assertion>)> {
119        let locked = self.0.lock();
120        let locked = locked.borrow();
121        locked.scenarios.clone()
122    }
123
124    pub fn add_scenario(&self, scenario: impl Into<String>) -> TractResult<()> {
125        let locked = self.0.lock();
126        let mut locked = locked.borrow_mut();
127        let s = scenario.into();
128        if !locked.scenarios.iter().any(|sc| sc.0 == s) {
129            locked.scenarios.push((s, vec![]));
130        }
131        Ok(())
132    }
133
134    pub fn add_scenario_assertion(
135        &self,
136        scenario: impl Into<String>,
137        assertion: impl Into<String>,
138    ) -> TractResult<()> {
139        let assert = parse_assertion(self, &assertion.into())?;
140        let s = scenario.into();
141        let locked = self.0.lock();
142        let mut locked = locked.borrow_mut();
143        if let Some(s) = locked.scenarios.iter_mut().find(|sc| sc.0 == s) {
144            s.1.push(assert);
145        } else {
146            locked.scenarios.push((s, vec![assert]));
147        }
148        Ok(())
149    }
150
151    pub fn with_scenario_assertion(
152        self,
153        scenario: impl Into<String>,
154        assertion: impl Into<String>,
155    ) -> TractResult<Self> {
156        self.add_scenario_assertion(scenario, assertion)?;
157        Ok(self)
158    }
159
160    pub fn with_scenario(self, scenario: impl Into<String>) -> TractResult<Self> {
161        self.add_scenario(scenario)?;
162        Ok(self)
163    }
164
165    pub fn all_symbols(&self) -> Vec<Symbol> {
166        self.0
167            .lock()
168            .borrow()
169            .table
170            .into_iter()
171            .map(|is| Symbol(Arc::downgrade(&self.0), is.0))
172            .collect()
173    }
174
175    pub fn guess_scenario(&self, values: &SymbolValues) -> TractResult<Option<usize>> {
176        let locked = self.0.lock();
177        let locked = locked.borrow();
178        if locked.scenarios.len() == 0 {
179            return Ok(None);
180        }
181        let mut maybe = None;
182        for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
183            if assertions.iter().any(|a| a.check(values) == Some(false)) {
184                continue;
185            } else if assertions.iter().all(|a| a.check(values) == Some(true)) {
186                return Ok(Some(ix));
187            } else if maybe.is_none() {
188                maybe = Some(ix);
189            } else {
190                return Ok(None);
191            }
192        }
193        if maybe.is_some() {
194            Ok(maybe)
195        } else {
196            anyhow::bail!("No possible scenario");
197        }
198    }
199}
200
201thread_local! {
202    static PROOF_CACHE: RefCell<Option<ProofCache>> = const { RefCell::new(None) };
203}
204
205struct ProofCache {
206    scope_id: usize,
207    depth: usize,
208    cache: HashMap<TDim, bool>,
209}
210
211pub struct ProofCacheSession {
212    active: bool,
213}
214
215impl ProofCacheSession {
216    pub fn new(scope_id: usize) -> Self {
217        let active = PROOF_CACHE.with(|cell| {
218            let mut borrow = cell.borrow_mut();
219            match &mut *borrow {
220                None => {
221                    *borrow = Some(ProofCache { scope_id, depth: 1, cache: HashMap::new() });
222                    true
223                }
224                Some(pc) if pc.scope_id == scope_id => {
225                    pc.depth += 1;
226                    true
227                }
228                Some(_) => false,
229            }
230        });
231        ProofCacheSession { active }
232    }
233}
234
235impl Drop for ProofCacheSession {
236    fn drop(&mut self) {
237        if !self.active {
238            return;
239        }
240        PROOF_CACHE.with(|cell| {
241            let mut borrow = cell.borrow_mut();
242            if let Some(pc) = &mut *borrow {
243                pc.depth -= 1;
244                if pc.depth == 0 {
245                    *borrow = None;
246                }
247            }
248        });
249    }
250}
251
252impl SymbolScopeData {
253    pub fn all_assertions(&self) -> &[Assertion] {
254        &self.assertions
255    }
256
257    pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator<Item = &'_ Assertion> {
258        self.assertions.iter().chain(
259            scenario
260                .and_then(|s| self.scenarios.iter().find(|s2| s2.0 == s))
261                .map(|s| &*s.1)
262                .unwrap_or(&[])
263                .iter(),
264        )
265    }
266
267    pub fn scenarios(&self) -> impl Iterator<Item = &'_ str> {
268        self.scenarios.iter().map(|s| &*s.0)
269    }
270
271    pub fn scenario(&self, s: &str) -> impl Iterator<Item = &'_ Assertion> {
272        self.scenarios.iter().find(|sc| sc.0 == s).map(|sc| &*sc.1).unwrap_or(&[]).iter()
273    }
274
275    pub fn resolving<R>(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option<R> {
276        self.table.resolve(sym.1).map(f)
277    }
278
279    #[allow(clippy::mutable_key_type)]
280    pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
281        if let TDim::Val(v) = t {
282            return *v >= 0;
283        }
284        let cached = PROOF_CACHE.with(|cell| {
285            let borrow = cell.borrow();
286            if let Some(pc) = &*borrow {
287                debug_assert_eq!(pc.scope_id, self.id, "ProofCacheSession scope_id mismatch");
288                pc.cache.get(t).copied()
289            } else {
290                None
291            }
292        });
293        if let Some(result) = cached {
294            return result;
295        }
296        let result = self.prove_positive_or_zero_inner(t);
297        PROOF_CACHE.with(|cell| {
298            let mut borrow = cell.borrow_mut();
299            if let Some(pc) = &mut *borrow {
300                pc.cache.insert(t.clone(), result);
301            }
302        });
303        result
304    }
305
306    #[allow(clippy::mutable_key_type)]
307    fn prove_positive_or_zero_inner(&self, t: &TDim) -> bool {
308        self.prove_positive_or_zero_inner_with_extra(t, &[])
309    }
310
311    #[allow(clippy::mutable_key_type)]
312    fn prove_positive_or_zero_inner_with_extra(&self, t: &TDim, extra: &[Assertion]) -> bool {
313        let positives = self
314            .assertions
315            .iter()
316            .chain(extra.iter())
317            .filter_map(|i| i.as_known_positive())
318            .collect_vec();
319        let mut visited = vec![];
320        let mut todo = vec![t.clone()];
321        while let Some(t) = todo.pop() {
322            if t.to_i64().is_ok_and(|i| i >= 0) {
323                return true;
324            }
325            if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
326                return true;
327            }
328            // Div(a, q) with q >= 1 is non-negative whenever a is non-negative.
329            if let TDim::Div(a, q) = &t {
330                if *q >= 1 && self.prove_positive_or_zero_inner_with_extra(a, extra) {
331                    return true;
332                }
333            }
334            let syms = t.symbols();
335            for s in syms {
336                let me = t.guess_slope(&s);
337                for pos in &positives {
338                    if pos.symbols().contains(&s) {
339                        let other = pos.guess_slope(&s);
340                        if me.0.signum() == other.0.signum() {
341                            let new = t.clone() * me.1 * other.0.abs()
342                                - pos.clone() * me.0.abs() * other.1;
343                            if !visited.contains(&new) {
344                                todo.push(new);
345                            }
346                        }
347                    }
348                }
349            }
350            visited.push(t);
351            if visited.len() > 10 {
352                break;
353            }
354        }
355        false
356    }
357
358    pub(crate) fn prove_positive_or_zero_with_extra(&self, t: &TDim, extra: &[Assertion]) -> bool {
359        if let TDim::Val(v) = t {
360            return *v >= 0;
361        }
362        // Skip the proof cache for extra-assertion calls (cache is keyed without extra context)
363        self.prove_positive_or_zero_inner_with_extra(t, extra)
364    }
365
366    pub(crate) fn prove_strict_positive_with_extra(&self, b: &TDim, extra: &[Assertion]) -> bool {
367        self.prove_positive_or_zero_with_extra(&(b.clone() - 1), extra)
368    }
369}
370
371impl fmt::Debug for SymbolScope {
372    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
373        let locked = self.0.lock();
374        let locked = locked.borrow();
375        write!(
376            f,
377            "symbols: {}; assertions: {}; {}",
378            locked.table.into_iter().map(|(_, s)| s).sorted().join(", "),
379            locked.assertions.iter().map(|s| s.to_string()).sorted().join(", "),
380            locked
381                .scenarios
382                .iter()
383                .map(|s| format!(
384                    "{}: {}",
385                    s.0,
386                    s.1.iter().map(|s| s.to_string()).sorted().join(", ")
387                ))
388                .join(" ; "),
389        )
390    }
391}
392
393#[derive(Clone)]
394pub struct Symbol(Weak<ReentrantMutex<RefCell<SymbolScopeData>>>, string_interner::DefaultSymbol);
395
396impl Eq for Symbol {}
397
398impl PartialEq for Symbol {
399    fn eq(&self, other: &Self) -> bool {
400        self.1 == other.1
401    }
402}
403
404impl Symbol {
405    pub fn scope(&self) -> Option<SymbolScope> {
406        self.0.upgrade().map(SymbolScope)
407    }
408}
409
410impl PartialOrd for Symbol {
411    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
412        Some(self.cmp(other))
413    }
414}
415
416impl Ord for Symbol {
417    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
418        self.1.cmp(&other.1)
419    }
420}
421
422impl std::hash::Hash for Symbol {
423    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
424        self.1.hash(state)
425    }
426}
427
428impl std::fmt::Display for Symbol {
429    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
430        if let Some(scope) = self.scope() {
431            let lock = scope.0.lock();
432            let lock = lock.borrow();
433            if let Some(s) = lock.table.resolve(self.1) {
434                return write!(f, "{s}");
435            }
436        }
437        write!(f, "<Sym{}>", self.1.to_usize())
438    }
439}
440
441impl fmt::Debug for Symbol {
442    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
443        Display::fmt(&self, f)
444    }
445}
446
447#[derive(Clone, Debug, Default)]
448pub struct SymbolValues {
449    values: HashMap<Symbol, i64>,
450}
451
452impl SymbolValues {
453    pub fn with(mut self, s: &Symbol, v: i64) -> Self {
454        self.set(s, v);
455        self
456    }
457
458    pub fn set(&mut self, s: &Symbol, v: i64) {
459        self.values.insert(s.clone(), v);
460    }
461
462    pub fn get(&self, s: &Symbol) -> Option<i64> {
463        self.values.get(s).copied()
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `assertion` in a fresh scope and checks that its known-positive form
    /// is the expression `expected`.
    fn check_known_positive(assertion: &str, expected: &str) {
        let scope = SymbolScope::default();
        let parsed = parse_assertion(&scope, assertion).unwrap();
        assert_eq!(parsed.as_known_positive(), Some(scope.parse_tdim(expected).unwrap()));
    }

    /// Parses `expr` in a fresh, assertion-free scope and reports whether it can be
    /// proven non-negative.
    fn proves_positive_or_zero(expr: &str) -> bool {
        let scope = SymbolScope::default();
        let dim = scope.parse_tdim(expr).unwrap();
        dim.prove_positive_or_zero()
    }

    #[test]
    fn as_known_positive_gte() {
        check_known_positive("S>=0", "S");
    }

    #[test]
    fn as_known_positive_gt() {
        check_known_positive("S>0", "S-1");
    }

    #[test]
    fn as_known_positive_lte() {
        check_known_positive("S<=0", "-S");
    }

    #[test]
    fn as_known_positive_lt() {
        check_known_positive("S<0", "-S - 1");
    }

    #[test]
    fn prove_positive_0() {
        assert!(proves_positive_or_zero("0"));
    }

    #[test]
    fn prove_positive_1() {
        assert!(proves_positive_or_zero("1"));
    }

    #[test]
    fn prove_positive_neg1() {
        assert!(!proves_positive_or_zero("-1"));
    }

    #[test]
    fn prove_positive_add_0() {
        // Nothing is known about `s`, so `s+1 >= 0` must not be provable.
        assert!(!proves_positive_or_zero("s+1"));
    }
}