1use itertools::Itertools;
2use parking_lot::ReentrantMutex;
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::fmt::{self, Display};
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::sync::{Arc, Weak};
8use string_interner::DefaultStringInterner;
9use string_interner::Symbol as _;
10
11use crate::TractResult;
12
13use super::parse::parse_assertion;
14use super::{Assertion, TDim, parse_tdim};
15
/// Process-wide source of unique scope identifiers; every new
/// `SymbolScopeData` takes the next value.
static SCOPE_COUNTER: AtomicUsize = AtomicUsize::new(0);
17
18pub struct SymbolScopeInner {
22 has_scenarios: AtomicBool,
26 data: ReentrantMutex<RefCell<SymbolScopeData>>,
27}
28
29impl Default for SymbolScopeInner {
30 fn default() -> Self {
31 SymbolScopeInner { has_scenarios: AtomicBool::new(false), data: Default::default() }
32 }
33}
34
35impl std::ops::Deref for SymbolScopeInner {
36 type Target = ReentrantMutex<RefCell<SymbolScopeData>>;
37 fn deref(&self) -> &Self::Target {
38 &self.data
39 }
40}
41
42#[derive(Clone, Default)]
43pub struct SymbolScope(pub Arc<SymbolScopeInner>);
44
45impl PartialEq for SymbolScope {
46 fn eq(&self, other: &Self) -> bool {
47 Arc::ptr_eq(&self.0, &other.0)
48 }
49}
50
51impl Eq for SymbolScope {}
52
53pub struct SymbolScopeData {
54 id: usize,
55 table: DefaultStringInterner,
56 assertions: Vec<Assertion>,
57 scenarios: Vec<(String, Vec<Assertion>)>,
58}
59
60impl Default for SymbolScopeData {
61 fn default() -> Self {
62 SymbolScopeData {
63 id: SCOPE_COUNTER.fetch_add(1, Ordering::Relaxed),
64 table: DefaultStringInterner::default(),
65 assertions: Vec::new(),
66 scenarios: Vec::new(),
67 }
68 }
69}
70
71impl SymbolScope {
72 pub fn id(&self) -> usize {
73 let locked = self.0.lock();
74 let locked = locked.borrow();
75 locked.id
76 }
77
78 pub fn proof_cache_session(&self) -> ProofCacheSession {
79 ProofCacheSession::new(self.id())
80 }
81
82 pub fn get(&self, name: &str) -> Option<Symbol> {
83 let locked = self.0.lock();
84 let locked = locked.borrow();
85 locked.table.get(name).map(|sym| Symbol(Arc::downgrade(&self.0), sym))
86 }
87
88 pub fn coord_sym(&self, k: usize) -> Symbol {
90 self.sym(&format!("🎯{k}"))
91 }
92
93 pub fn sym(&self, name: &str) -> Symbol {
94 let locked = self.0.lock();
95 let mut locked = locked.borrow_mut();
96 let sym = locked.table.get_or_intern(name);
97 Symbol(Arc::downgrade(&self.0), sym)
98 }
99
100 pub fn new_with_prefix(&self, prefix: &str) -> Symbol {
101 let locked = self.0.lock();
102 let mut locked = locked.borrow_mut();
103 let sym = if locked.table.get(prefix).is_none() {
104 locked.table.get_or_intern(prefix)
105 } else {
106 let mut i = 0;
107 loop {
108 let s = format!("{prefix}_{i}");
109 if locked.table.get(&s).is_none() {
110 break locked.table.get_or_intern(s);
111 }
112 i += 1;
113 }
114 };
115 Symbol(Arc::downgrade(&self.0), sym)
116 }
117
118 pub fn parse_tdim(&self, input: impl AsRef<str>) -> TractResult<TDim> {
119 parse_tdim(self, input.as_ref())
120 }
121
122 pub fn add_assertion(&self, assert: impl Into<String>) -> TractResult<()> {
123 let assert = assert.into();
124 let assert = parse_assertion(self, &assert)?;
125 let locked = self.0.lock();
126 let mut locked = locked.borrow_mut();
127 locked.assertions.push(assert);
128 Ok(())
129 }
130
131 pub fn with_assertion(self, assert: impl Into<String>) -> TractResult<Self> {
132 self.add_assertion(assert)?;
133 Ok(self)
134 }
135
136 pub fn all_assertions(&self) -> Vec<Assertion> {
137 let locked = self.0.lock();
138 let locked = locked.borrow();
139 locked.assertions.clone()
140 }
141
142 pub fn all_scenarios(&self) -> impl IntoIterator<Item = (String, Vec<Assertion>)> {
143 let locked = self.0.lock();
144 let locked = locked.borrow();
145 locked.scenarios.clone()
146 }
147
148 pub fn add_scenario(&self, scenario: impl Into<String>) -> TractResult<()> {
149 let locked = self.0.lock();
150 let mut locked = locked.borrow_mut();
151 let s = scenario.into();
152 if !locked.scenarios.iter().any(|sc| sc.0 == s) {
153 locked.scenarios.push((s, vec![]));
154 self.0.has_scenarios.store(true, Ordering::Relaxed);
155 }
156 Ok(())
157 }
158
159 pub fn add_scenario_assertion(
160 &self,
161 scenario: impl Into<String>,
162 assertion: impl Into<String>,
163 ) -> TractResult<()> {
164 let assert = parse_assertion(self, &assertion.into())?;
165 let s = scenario.into();
166 let locked = self.0.lock();
167 let mut locked = locked.borrow_mut();
168 if let Some(s) = locked.scenarios.iter_mut().find(|sc| sc.0 == s) {
169 s.1.push(assert);
170 } else {
171 locked.scenarios.push((s, vec![assert]));
172 self.0.has_scenarios.store(true, Ordering::Relaxed);
173 }
174 Ok(())
175 }
176
177 pub fn with_scenario_assertion(
178 self,
179 scenario: impl Into<String>,
180 assertion: impl Into<String>,
181 ) -> TractResult<Self> {
182 self.add_scenario_assertion(scenario, assertion)?;
183 Ok(self)
184 }
185
186 pub fn with_scenario(self, scenario: impl Into<String>) -> TractResult<Self> {
187 self.add_scenario(scenario)?;
188 Ok(self)
189 }
190
191 pub fn all_symbols(&self) -> Vec<Symbol> {
192 self.0
193 .lock()
194 .borrow()
195 .table
196 .into_iter()
197 .map(|is| Symbol(Arc::downgrade(&self.0), is.0))
198 .collect()
199 }
200
201 pub fn guess_scenario(&self, values: &SymbolValues) -> TractResult<Option<usize>> {
202 if !self.0.has_scenarios.load(Ordering::Relaxed) {
204 return Ok(None);
205 }
206 let locked = self.0.lock();
207 let locked = locked.borrow();
208 if locked.scenarios.is_empty() {
209 return Ok(None);
210 }
211 let mut maybe = None;
212 for (ix, (_name, assertions)) in locked.scenarios.iter().enumerate() {
213 if assertions.iter().any(|a| a.check(values) == Some(false)) {
214 continue;
215 } else if assertions.iter().all(|a| a.check(values) == Some(true)) {
216 return Ok(Some(ix));
217 } else if maybe.is_none() {
218 maybe = Some(ix);
219 } else {
220 return Ok(None);
221 }
222 }
223 if maybe.is_some() {
224 Ok(maybe)
225 } else {
226 anyhow::bail!("No possible scenario");
227 }
228 }
229}
230
231thread_local! {
232 static PROOF_CACHE: RefCell<Option<ProofCache>> = const { RefCell::new(None) };
233}
234
235struct ProofCache {
236 scope_id: usize,
237 depth: usize,
238 cache: HashMap<TDim, bool>,
239}
240
241pub struct ProofCacheSession {
242 active: bool,
243}
244
245impl ProofCacheSession {
246 pub fn new(scope_id: usize) -> Self {
247 let active = PROOF_CACHE.with(|cell| {
248 let mut borrow = cell.borrow_mut();
249 match &mut *borrow {
250 None => {
251 *borrow = Some(ProofCache { scope_id, depth: 1, cache: HashMap::new() });
252 true
253 }
254 Some(pc) if pc.scope_id == scope_id => {
255 pc.depth += 1;
256 true
257 }
258 Some(_) => false,
259 }
260 });
261 ProofCacheSession { active }
262 }
263}
264
265impl Drop for ProofCacheSession {
266 fn drop(&mut self) {
267 if !self.active {
268 return;
269 }
270 PROOF_CACHE.with(|cell| {
271 let mut borrow = cell.borrow_mut();
272 if let Some(pc) = &mut *borrow {
273 pc.depth -= 1;
274 if pc.depth == 0 {
275 *borrow = None;
276 }
277 }
278 });
279 }
280}
281
282impl SymbolScopeData {
283 pub fn all_assertions(&self) -> &[Assertion] {
284 &self.assertions
285 }
286
287 pub fn assertions(&self, scenario: Option<&str>) -> impl Iterator<Item = &'_ Assertion> {
288 self.assertions.iter().chain(
289 scenario
290 .and_then(|s| self.scenarios.iter().find(|s2| s2.0 == s))
291 .map(|s| &*s.1)
292 .unwrap_or(&[])
293 .iter(),
294 )
295 }
296
297 pub fn scenarios(&self) -> impl Iterator<Item = &'_ str> {
298 self.scenarios.iter().map(|s| &*s.0)
299 }
300
301 pub fn scenario(&self, s: &str) -> impl Iterator<Item = &'_ Assertion> {
302 self.scenarios.iter().find(|sc| sc.0 == s).map(|sc| &*sc.1).unwrap_or(&[]).iter()
303 }
304
305 pub fn resolving<R>(&self, sym: &Symbol, f: impl FnOnce(&str) -> R) -> Option<R> {
306 self.table.resolve(sym.1).map(f)
307 }
308
309 #[allow(clippy::mutable_key_type)]
310 pub fn prove_positive_or_zero(&self, t: &TDim) -> bool {
311 if let TDim::Val(v) = t {
312 return *v >= 0;
313 }
314 let cached = PROOF_CACHE.with(|cell| {
315 let borrow = cell.borrow();
316 if let Some(pc) = &*borrow {
317 debug_assert_eq!(pc.scope_id, self.id, "ProofCacheSession scope_id mismatch");
318 pc.cache.get(t).copied()
319 } else {
320 None
321 }
322 });
323 if let Some(result) = cached {
324 return result;
325 }
326 let result = self.prove_positive_or_zero_inner(t);
327 PROOF_CACHE.with(|cell| {
328 let mut borrow = cell.borrow_mut();
329 if let Some(pc) = &mut *borrow {
330 pc.cache.insert(t.clone(), result);
331 }
332 });
333 result
334 }
335
336 #[allow(clippy::mutable_key_type)]
337 fn prove_positive_or_zero_inner(&self, t: &TDim) -> bool {
338 self.prove_positive_or_zero_inner_with_extra(t, &[])
339 }
340
341 #[allow(clippy::mutable_key_type)]
342 fn prove_positive_or_zero_inner_with_extra(&self, t: &TDim, extra: &[Assertion]) -> bool {
343 let positives = self
344 .assertions
345 .iter()
346 .chain(extra.iter())
347 .filter_map(|i| i.as_known_positive())
348 .collect_vec();
349 let mut visited = vec![];
350 let mut todo = vec![t.clone()];
351 while let Some(t) = todo.pop() {
352 if t.to_i64().is_ok_and(|i| i >= 0) {
353 return true;
354 }
355 if t.inclusive_bound(self, false).is_some_and(|l| l >= 0) {
356 return true;
357 }
358 if let TDim::Div(a, q) = &t {
360 if *q >= 1 && self.prove_positive_or_zero_inner_with_extra(a, extra) {
361 return true;
362 }
363 }
364 let syms = t.symbols();
365 for s in syms {
366 let me = t.guess_slope(&s);
367 for pos in &positives {
368 if pos.symbols().contains(&s) {
369 let other = pos.guess_slope(&s);
370 if me.0.signum() == other.0.signum() {
371 let new = t.clone() * me.1 * other.0.abs()
372 - pos.clone() * me.0.abs() * other.1;
373 if !visited.contains(&new) {
374 todo.push(new);
375 }
376 }
377 }
378 }
379 }
380 visited.push(t);
381 if visited.len() > 10 {
382 break;
383 }
384 }
385 false
386 }
387
388 pub(crate) fn prove_positive_or_zero_with_extra(&self, t: &TDim, extra: &[Assertion]) -> bool {
389 if let TDim::Val(v) = t {
390 return *v >= 0;
391 }
392 self.prove_positive_or_zero_inner_with_extra(t, extra)
394 }
395
396 pub(crate) fn prove_strict_positive_with_extra(&self, b: &TDim, extra: &[Assertion]) -> bool {
397 self.prove_positive_or_zero_with_extra(&(b.clone() - 1), extra)
398 }
399}
400
401impl fmt::Debug for SymbolScope {
402 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403 let locked = self.0.lock();
404 let locked = locked.borrow();
405 write!(
406 f,
407 "symbols: {}; assertions: {}; {}",
408 locked.table.into_iter().map(|(_, s)| s).sorted().join(", "),
409 locked.assertions.iter().map(|s| s.to_string()).sorted().join(", "),
410 locked
411 .scenarios
412 .iter()
413 .map(|s| format!(
414 "{}: {}",
415 s.0,
416 s.1.iter().map(|s| s.to_string()).sorted().join(", ")
417 ))
418 .join(" ; "),
419 )
420 }
421}
422
423#[derive(Clone)]
424pub struct Symbol(Weak<SymbolScopeInner>, string_interner::DefaultSymbol);
425
426impl Eq for Symbol {}
427
428impl PartialEq for Symbol {
429 fn eq(&self, other: &Self) -> bool {
430 self.1 == other.1
431 }
432}
433
434impl Symbol {
435 pub fn scope(&self) -> Option<SymbolScope> {
436 self.0.upgrade().map(SymbolScope)
437 }
438}
439
440impl PartialOrd for Symbol {
441 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
442 Some(self.cmp(other))
443 }
444}
445
446impl Ord for Symbol {
447 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
448 self.1.cmp(&other.1)
449 }
450}
451
452impl std::hash::Hash for Symbol {
453 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
454 self.1.hash(state)
455 }
456}
457
458impl std::fmt::Display for Symbol {
459 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
460 if let Some(scope) = self.scope() {
461 let lock = scope.0.lock();
462 let lock = lock.borrow();
463 if let Some(s) = lock.table.resolve(self.1) {
464 return write!(f, "{s}");
465 }
466 }
467 write!(f, "<Sym{}>", self.1.to_usize())
468 }
469}
470
471impl fmt::Debug for Symbol {
472 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
473 Display::fmt(&self, f)
474 }
475}
476
477#[derive(Clone, Debug, Default)]
478pub struct SymbolValues {
479 values: HashMap<Symbol, i64>,
480}
481
482impl SymbolValues {
483 pub fn with(mut self, s: &Symbol, v: i64) -> Self {
484 self.set(s, v);
485 self
486 }
487
488 pub fn set(&mut self, s: &Symbol, v: i64) {
489 self.values.insert(s.clone(), v);
490 }
491
492 pub fn get(&self, s: &Symbol) -> Option<i64> {
493 self.values.get(s).copied()
494 }
495
496 pub fn to_dim_map(&self) -> HashMap<Symbol, TDim> {
499 self.values.iter().map(|(s, v)| (s.clone(), TDim::Val(*v))).collect()
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `assertion` in `scope` and returns its known-non-negative form.
    fn known_positive(scope: &SymbolScope, assertion: &str) -> Option<TDim> {
        parse_assertion(scope, assertion).unwrap().as_known_positive()
    }

    /// Parses `expr` in a fresh scope and tries to prove it is `>= 0`.
    fn provably_non_negative(expr: &str) -> bool {
        let scope = SymbolScope::default();
        let dim = scope.parse_tdim(expr).unwrap();
        dim.prove_positive_or_zero()
    }

    #[test]
    fn as_known_positive_gte() {
        let scope = SymbolScope::default();
        let expected = scope.parse_tdim("S").unwrap();
        assert_eq!(known_positive(&scope, "S>=0"), Some(expected));
    }

    #[test]
    fn as_known_positive_gt() {
        let scope = SymbolScope::default();
        let expected = scope.parse_tdim("S-1").unwrap();
        assert_eq!(known_positive(&scope, "S>0"), Some(expected));
    }

    #[test]
    fn as_known_positive_lte() {
        let scope = SymbolScope::default();
        let expected = scope.parse_tdim("-S").unwrap();
        assert_eq!(known_positive(&scope, "S<=0"), Some(expected));
    }

    #[test]
    fn as_known_positive_lt() {
        let scope = SymbolScope::default();
        let expected = scope.parse_tdim("-S - 1").unwrap();
        assert_eq!(known_positive(&scope, "S<0"), Some(expected));
    }

    #[test]
    fn prove_positive_0() {
        assert!(provably_non_negative("0"));
    }

    #[test]
    fn prove_positive_1() {
        assert!(provably_non_negative("1"));
    }

    #[test]
    fn prove_positive_neg1() {
        assert!(!provably_non_negative("-1"));
    }

    #[test]
    fn prove_positive_add_0() {
        // Without any assertion on `s`, `s+1 >= 0` cannot be proven.
        assert!(!provably_non_negative("s+1"));
    }
}