Skip to main content

tract_data/dim/
sym.rs

1use itertools::Itertools;
2use parking_lot::ReentrantMutex;
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::fmt::{self, Display};
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::sync::{Arc, Weak};
8use string_interner::DefaultStringInterner;
9use string_interner::Symbol as _;
10
11use crate::TractResult;
12
13use super::parse::parse_assertion;
14use super::{Assertion, TDim, parse_tdim};
15
/// Process-wide counter; each `SymbolScopeData::default()` takes the next value as its `id`.
16static SCOPE_COUNTER: AtomicUsize = AtomicUsize::new(0);
17
/// Cloneable handle on a shared table of symbols, assertions and scenarios.
///
/// Cloning copies the `Arc`, so every clone refers to the same `SymbolScopeData`.
/// The data is reached through a `ReentrantMutex` wrapping a `RefCell`: the lock is taken
/// first, then the cell is borrowed for the duration of each operation.
18#[derive(Clone, Default)]
19pub struct SymbolScope(pub Arc<ReentrantMutex<RefCell<SymbolScopeData>>>);
20
21impl PartialEq for SymbolScope {
22    fn eq(&self, other: &Self) -> bool {
23        Arc::ptr_eq(&self.0, &other.0)
24    }
25}
26
// Equality is pointer identity (see `PartialEq`), which is a full equivalence relation.
27impl Eq for SymbolScope {}
28
/// State shared by every clone of a `SymbolScope`.
29pub struct SymbolScopeData {
    /// Value drawn from `SCOPE_COUNTER` at creation; used to tag the thread-local proof cache.
30    id: usize,
    /// String interner mapping symbol names to their interned handles.
31    table: DefaultStringInterner,
    /// Assertions that apply regardless of scenario.
32    assertions: Vec<Assertion>,
    /// Named scenarios, in insertion order, each with its own list of assertions.
33    scenarios: Vec<(String, Vec<Assertion>)>,
34}
35
36impl Default for SymbolScopeData {
37    fn default() -> Self {
38        SymbolScopeData {
39            id: SCOPE_COUNTER.fetch_add(1, Ordering::Relaxed),
40            table: DefaultStringInterner::default(),
41            assertions: Vec::new(),
42            scenarios: Vec::new(),
43        }
44    }
45}
46
47impl SymbolScope {
48    pub fn id(&self) -> usize {
49        let locked = self.0.lock();
50        let locked = locked.borrow();
51        locked.id
52    }
53
54    pub fn proof_cache_session(&self) -> ProofCacheSession {
55        ProofCacheSession::new(self.id())
56    }
57
58    pub fn get(&self, name: &str) -> Option<Symbol> {
59        let locked = self.0.lock();
60        let locked = locked.borrow();
61        locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym))
62    }
63
64    pub fn sym(&self, name: &str) -> Symbol {
65        let locked = self.0.lock();
66        let mut locked = locked.borrow_mut();
67        let sym = locked.table.get_or_intern(name);
68        Symbol(Arc::downgrade(&self.0), sym)
69    }
70
71    pub fn new_with_prefix(&self, prefix: &str) -> Symbol {
72        let locked = self.0.lock();
73        let mut locked = locked.borrow_mut();
74        let sym = if locked.table.get(prefix).is_none() {
75            locked.table.get_or_intern(prefix)
76        } else {
77            let mut i = 0;
78            loop {
79                let s = format!("{prefix}_{i}");
80                if locked.table.get(&s).is_none() {
81                    break locked.table.get_or_intern(s);
82                }
83                i += 1;
84            }
85        };
86        Symbol(Arc::downgrade(&self.0), sym)
87    }
88
89    pub fn parse_tdim(&self, input: impl AsRef<str>) -> TractResult<TDim> {
90        parse_tdim(self, input.as_ref())
91    }
92
93    pub fn add_assertion(&self, assert: impl Into<String>) -> TractResult<()> {
94        let assert = assert.into();
95        let assert = parse_assertion(self, &assert)?;
96        let locked = self.0.lock();
97        let mut locked = locked.borrow_mut();
98        locked.assertions.push(assert);
99        Ok(())
100    }
101
102    pub fn with_assertion(self, assert: impl Into<String>) -> TractResult<Self> {
103        self.add_assertion(assert)?;
104        Ok(self)
105    }
106
107    pub fn all_assertions(&self) -> Vec<Assertion> {
108        let locked = self.0.lock();
109        let locked = locked.borrow();
110        locked.assertions.clone()
111    }
112
113    pub fn all_scenarios(&self) -> impl IntoIterator<Item = (String, Vec<Assertion>)> {
114        let locked = self.0.lock();
115        let locked = locked.borrow();
116        locked.scenarios.clone()
117    }
118
119    pub fn add_scenario(&self, scenario: impl Into<String>) -> TractResult<()> {
120        let locked = self.0.lock();
121        let mut locked = locked.borrow_mut();
122        let s = scenario.into();
123        if !locked.scenarios.iter().any(|sc| sc.0 == s) {
124            locked.scenarios.push((s, vec![]));
125        }
126        Ok(())
127    }
128
129    pub fn add_scenario_assertion(
130        &self,
131        scenario: impl Into<String>,
132        assertion: impl Into<String>,
133    ) -> TractResult<()> {
134        let assert = parse_assertion(self, &assertion.into())?;
135        let s = scenario.into();
136        let locked = self.0.lock();
137        let mut locked = locked.borrow_mut();
138        if let Some(s) = locked.scenarios.iter_mut().find(|sc| sc.0 == s) {
139            s.1.push(assert);
140        } else {
141            locked.scenarios.push((s, vec![assert]));
142        }
143        Ok(())
144    }
145
146    pub fn with_scenario_assertion(
147        self,
148        scenario: impl Into<String>,
149        assertion: impl Into<String>,
150    ) -> TractResult<Self> {
151        self.add_scenario_assertion(scenario, assertion)?;
152        Ok(self)
153    }
154
155    pub fn with_scenario(self, scenario: impl Into<String>) -> TractResult<Self> {
156        self.add_scenario(scenario)?;
157        Ok(self)
158    }
159
160    pub fn all_symbols(&self) -> Vec<Symbol> {
161        self.0
162            .lock()
163            .borrow()
164            .table
165            .into_iter()
166            .map(|is| Symbol(Arc::downgrade(&self.0), is.0))
167            .collect()
168    }
169
170    pub fn guess_scenario(&self, values: &SymbolValues) -> TractResult<Option<usize>> {
171        let locked = self.0.lock();
172        let locked = locked.borrow();
173        if locked.scenarios.len() == 0 {
174            return Ok(None);
175        }
176        let mut maybe = None;
177        for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
178            if assertions.iter().any(|a| a.check(values) == Some(false)) {
179                continue;
180            } else if assertions.iter().all(|a| a.check(values) == Some(true)) {
181                return Ok(Some(ix));
182            } else if maybe.is_none() {
183                maybe = Some(ix);
184            } else {
185                return Ok(None);
186            }
187        }
188        if maybe.is_some() {
189            Ok(maybe)
190        } else {
191            anyhow::bail!("No possible scenario");
192        }
193    }
194}
195
// Per-thread proof cache. `None` means no `ProofCacheSession` is active on this thread.
196thread_local! {
197    static PROOF_CACHE: RefCell<Option<ProofCache>> = const { RefCell::new(None) };
198}
199
/// Memoized proof results, installed in `PROOF_CACHE` while a session is active.
200struct ProofCache {
    /// Id of the scope the session was opened for.
201    scope_id: usize,
    /// Number of active nested sessions; the cache is removed when it drops back to zero.
202    depth: usize,
    /// Results of `prove_positive_or_zero`, keyed by the expression that was tested.
203    cache: HashMap<TDim, bool>,
204}
205
/// Guard keeping the thread-local proof cache alive.
///
/// The session ends when the guard is dropped, so it must be bound to a variable
/// (`let _session = ...`); discarding the value would end the session immediately.
#[must_use = "the proof cache session ends as soon as this guard is dropped"]
pub struct ProofCacheSession {
    /// `true` when this guard holds a share of the thread-local cache and must release it
    /// on drop; `false` when it was created while a cache for another scope was installed.
    active: bool,
}
209
210impl ProofCacheSession {
211    pub fn new(scope_id: usize) -> Self {
212        let active = PROOF_CACHE.with(|cell| {
213            let mut borrow = cell.borrow_mut();
214            match &mut *borrow {
215                None => {
216                    *borrow = Some(ProofCache { scope_id, depth: 1, cache: HashMap::new() });
217                    true
218                }
219                Some(pc) if pc.scope_id == scope_id => {
220                    pc.depth += 1;
221                    true
222                }
223                Some(_) => false,
224            }
225        });
226        ProofCacheSession { active }
227    }
228}
229
230impl Drop for ProofCacheSession {
231    fn drop(&mut self) {
232        if !self.active {
233            return;
234        }
235        PROOF_CACHE.with(|cell| {
236            let mut borrow = cell.borrow_mut();
237            if let Some(pc) = &mut *borrow {
238                pc.depth -= 1;
239                if pc.depth == 0 {
240                    *borrow = None;
241                }
242            }
243        });
244    }
245}
246
247impl SymbolScopeData {
248    pub fn all_assertions(&self) -> &[Assertion] {
249        &self.assertions
250    }
251
252    pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator<Item = &'_ Assertion> {
253        self.assertions.iter().chain(
254            scenario
255                .and_then(|s| self.scenarios.iter().find(|s2| s2.0 == s))
256                .map(|s| &*s.1)
257                .unwrap_or(&[])
258                .iter(),
259        )
260    }
261
262    pub fn scenarios(&self) -> impl Iterator<Item = &'_ str> {
263        self.scenarios.iter().map(|s| &*s.0)
264    }
265
266    pub fn scenario(&self, s: &str) -> impl Iterator<Item = &'_ Assertion> {
267        self.scenarios.iter().find(|sc| sc.0 == s).map(|sc| &*sc.1).unwrap_or(&[]).iter()
268    }
269
270    pub fn resolving<R>(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option<R> {
271        self.table.resolve(sym.1).map(f)
272    }
273
274    #[allow(clippy::mutable_key_type)]
275    pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
276        if let TDim::Val(v) = t {
277            return *v >= 0;
278        }
279        let cached = PROOF_CACHE.with(|cell| {
280            let borrow = cell.borrow();
281            if let Some(pc) = &*borrow {
282                debug_assert_eq!(pc.scope_id, self.id, "ProofCacheSession scope_id mismatch");
283                pc.cache.get(t).copied()
284            } else {
285                None
286            }
287        });
288        if let Some(result) = cached {
289            return result;
290        }
291        let result = self.prove_positive_or_zero_inner(t);
292        PROOF_CACHE.with(|cell| {
293            let mut borrow = cell.borrow_mut();
294            if let Some(pc) = &mut *borrow {
295                pc.cache.insert(t.clone(), result);
296            }
297        });
298        result
299    }
300
301    #[allow(clippy::mutable_key_type)]
302    fn prove_positive_or_zero_inner(&self, t: &TDim) -> bool {
303        let positives = self.assertions.iter().filter_map(|i| i.as_known_positive()).collect_vec();
304        let mut visited = vec![];
305        let mut todo = vec![t.clone()];
306        while let Some(t) = todo.pop() {
307            if t.to_i64().is_ok_and(|i| i >= 0) {
308                return true;
309            }
310            if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
311                return true;
312            }
313            let syms = t.symbols();
314            for s in syms {
315                let me = t.guess_slope(&s);
316                for pos in &positives {
317                    if pos.symbols().contains(&s) {
318                        let other = pos.guess_slope(&s);
319                        if me.0.signum() == other.0.signum() {
320                            let new = t.clone() * me.1 * other.0.abs()
321                                - pos.clone() * me.0.abs() * other.1;
322                            if !visited.contains(&new) {
323                                todo.push(new);
324                            }
325                        }
326                    }
327                }
328            }
329            visited.push(t);
330            if visited.len() > 10 {
331                break;
332            }
333        }
334        false
335    }
336
337    pub(crate) fn prove_strict_positive(&self, b: &TDim) -> bool {
338        self.prove_positive_or_zero(&(b.clone() - 1))
339    }
340}
341
342impl fmt::Debug for SymbolScope {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344        let locked = self.0.lock();
345        let locked = locked.borrow();
346        write!(
347            f,
348            "symbols: {}; assertions: {}; {}",
349            locked.table.into_iter().map(|(_, s)| s).sorted().join(", "),
350            locked.assertions.iter().map(|s| s.to_string()).sorted().join(", "),
351            locked
352                .scenarios
353                .iter()
354                .map(|s| format!(
355                    "{}: {}",
356                    s.0,
357                    s.1.iter().map(|s| s.to_string()).sorted().join(", ")
358                ))
359                .join(" ; "),
360        )
361    }
362}
363
/// A symbol: a weak reference to the scope data it was created from, plus its interned
/// handle in that scope's table.
///
/// The reference is weak, so a `Symbol` does not keep its scope alive.
364#[derive(Clone)]
365pub struct Symbol(Weak<ReentrantMutex<RefCell<SymbolScopeData>>>, string_interner::DefaultSymbol);
366
// Equality compares the interned handles only (see `PartialEq`).
367impl Eq for Symbol {}
368
369impl PartialEq for Symbol {
370    fn eq(&self, other: &Self) -> bool {
371        self.1 == other.1
372    }
373}
374
375impl Symbol {
376    pub fn scope(&self) -> Option<SymbolScope> {
377        self.0.upgrade().map(SymbolScope)
378    }
379}
380
// Delegates to `Ord::cmp`, so the result is always `Some(..)`.
381impl PartialOrd for Symbol {
382    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
383        Some(self.cmp(other))
384    }
385}
386
387impl Ord for Symbol {
388    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
389        self.1.cmp(&other.1)
390    }
391}
392
393impl std::hash::Hash for Symbol {
394    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
395        self.1.hash(state)
396    }
397}
398
399impl std::fmt::Display for Symbol {
400    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401        if let Some(scope) = self.scope() {
402            let lock = scope.0.lock();
403            let lock = lock.borrow();
404            if let Some(s) = lock.table.resolve(self.1) {
405                return write!(f, "{s}");
406            }
407        }
408        write!(f, "<Sym{}>", self.1.to_usize())
409    }
410}
411
412impl fmt::Debug for Symbol {
413    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
414        Display::fmt(&self, f)
415    }
416}
417
/// A set of concrete integer values assigned to symbols.
418#[derive(Clone, Debug, Default)]
419pub struct SymbolValues {
    /// Value of each symbol that has been set; symbols that are absent have no value.
420    values: HashMap<Symbol, i64>,
421}
422
423impl SymbolValues {
424    pub fn with(mut self, s: &Symbol, v: i64) -> Self {
425        self.set(s, v);
426        self
427    }
428
429    pub fn set(&mut self, s: &Symbol, v: i64) {
430        self.values.insert(s.clone(), v);
431    }
432
433    pub fn get(&self, s: &Symbol) -> Option<i64> {
434        self.values.get(s).copied()
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `assertion` in a fresh scope and checks that its known-positive form equals
    /// `expected`, parsed in the same scope.
    #[track_caller]
    fn check_known_positive(assertion: &str, expected: &str) {
        let scope = SymbolScope::default();
        let found = parse_assertion(&scope, assertion).unwrap().as_known_positive();
        assert_eq!(found, Some(scope.parse_tdim(expected).unwrap()));
    }

    /// Parses `expr` in a fresh scope, without any assertion, and tries to prove it `>= 0`.
    fn proves_positive_or_zero(expr: &str) -> bool {
        let scope = SymbolScope::default();
        let dim = scope.parse_tdim(expr).unwrap();
        dim.prove_positive_or_zero()
    }

    #[test]
    fn as_known_positive_gte() {
        check_known_positive("S>=0", "S");
    }

    #[test]
    fn as_known_positive_gt() {
        check_known_positive("S>0", "S-1");
    }

    #[test]
    fn as_known_positive_lte() {
        check_known_positive("S<=0", "-S");
    }

    #[test]
    fn as_known_positive_lt() {
        check_known_positive("S<0", "-S - 1");
    }

    #[test]
    fn prove_positive_0() {
        assert!(proves_positive_or_zero("0"));
    }

    #[test]
    fn prove_positive_1() {
        assert!(proves_positive_or_zero("1"));
    }

    #[test]
    fn prove_positive_neg1() {
        assert!(!proves_positive_or_zero("-1"));
    }

    #[test]
    fn prove_positive_add_0() {
        // Nothing is known about `s`, so `s+1 >= 0` cannot be proven.
        assert!(!proves_positive_or_zero("s+1"));
    }
}
501}