1use std::{
2 f64::EPSILON,
3 fmt::Display,
4 ops::{Add, Div, Mul, Sub},
5};
6
7use anyhow::Result;
8use fraction::GenericFraction;
9use thiserror::Error;
10use z3::ast::Real;
11
12use super::{
13 context::{LayoutContext, Z3BuildContext},
14 prop::{Prop, PropVariant},
15};
16use std::fmt::Debug;
17
18#[derive(Copy, Clone)]
20pub struct Measure<'a> {
21 pub ctx: &'a LayoutContext,
22 pub(super) variant: &'a MeasureVariant<'a>,
23}
24
25impl<'a> Debug for Measure<'a> {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 write!(f, "Measure {{ {:?} }}", self.variant)
28 }
29}
30
31#[derive(Error, Debug)]
32pub enum MeasureError {
33 #[error("bad const")]
34 BadConst,
35}
36
37#[derive(Copy, Clone, Debug)]
38pub enum MeasureVariant<'a> {
39 Unbound,
40 Const(i32, i32),
41 Add(Measure<'a>, Measure<'a>),
42 Sub(Measure<'a>, Measure<'a>),
43 Mul(Measure<'a>, Measure<'a>),
44 Div(Measure<'a>, Measure<'a>),
45 Select(Prop<'a>, Measure<'a>, Measure<'a>),
46}
47
/// Wrapper that unconditionally claims its payload is `Send` and `Sync`.
///
/// It exists so that `SMALL_MEASURE_CONSTS` can be a `static` (statics must be `Sync`).
/// NOTE(review): presumably `MeasureVariant<'static>` is `!Sync` because its variants can hold
/// `Measure`/`Prop` values that reference a `LayoutContext`; confirm.
struct UnsafelyAssumeThreadSafe<Inner>(Inner);

// SAFETY: NOT sound for an arbitrary `Inner`. This is only acceptable because the use visible in
// this file wraps an immutable array of `MeasureVariant::Const` values, which carry plain `i32`s,
// no references, and no interior mutability. Do not wrap any other data in this type.
unsafe impl<Inner> Send for UnsafelyAssumeThreadSafe<Inner> {}

// SAFETY: same justification as the `Send` impl above: the wrapped table is never mutated and its
// entries contain no shared references, so concurrent reads cannot race.
unsafe impl<Inner> Sync for UnsafelyAssumeThreadSafe<Inner> {}
51
52static SMALL_MEASURE_CONSTS: UnsafelyAssumeThreadSafe<[MeasureVariant<'static>; 16]> =
53 UnsafelyAssumeThreadSafe([
54 MeasureVariant::Const(0, 1),
55 MeasureVariant::Const(1, 1),
56 MeasureVariant::Const(2, 1),
57 MeasureVariant::Const(3, 1),
58 MeasureVariant::Const(4, 1),
59 MeasureVariant::Const(5, 1),
60 MeasureVariant::Const(6, 1),
61 MeasureVariant::Const(7, 1),
62 MeasureVariant::Const(8, 1),
63 MeasureVariant::Const(9, 1),
64 MeasureVariant::Const(10, 1),
65 MeasureVariant::Const(11, 1),
66 MeasureVariant::Const(12, 1),
67 MeasureVariant::Const(13, 1),
68 MeasureVariant::Const(14, 1),
69 MeasureVariant::Const(15, 1),
70 ]);
71
72#[allow(dead_code)]
73impl<'a> Measure<'a> {
74 pub fn zero(ctx: &'a LayoutContext) -> Self {
75 Measure {
76 ctx,
77 variant: &SMALL_MEASURE_CONSTS.0[0],
78 }
79 }
80
81 pub fn new_const(ctx: &'a LayoutContext, value: f64) -> Result<Self, MeasureError> {
82 let value = ((value * 100.0) as i64) as f64 / 100.0;
83
84 if ((value as i64) as f64 - value).abs() < EPSILON {
86 let candidates = &SMALL_MEASURE_CONSTS.0;
87 let index = value as i64;
88 if index >= 0 && index < candidates.len() as i64 {
89 return Ok(Measure {
90 ctx,
91 variant: &candidates[index as usize],
92 });
93 }
94 }
95
96 let frac = GenericFraction::<i32>::from(value);
97 let sign: i32 = if value < 0.0 { -1 } else { 1 };
98 Ok(Measure {
99 ctx,
100 variant: ctx.alloc.alloc(MeasureVariant::Const(
101 *frac.numer().ok_or_else(|| MeasureError::BadConst)? * sign,
102 *frac.denom().ok_or_else(|| MeasureError::BadConst)?,
103 )),
104 })
105 }
106
107 pub fn new_unbound(ctx: &'a LayoutContext) -> Self {
108 Measure {
109 ctx,
110 variant: ctx.alloc.alloc(MeasureVariant::Unbound),
111 }
112 }
113
114 pub fn is_unbound(&self) -> bool {
115 match self.variant {
116 &MeasureVariant::Unbound => true,
117 _ => false,
118 }
119 }
120
121 pub fn build_z3<'ctx>(self, build_ctx: &mut Z3BuildContext<'ctx>) -> Result<Real<'ctx>> {
122 let key = self.variant as *const _ as usize;
123 if let Some(x) = build_ctx.measure_cache.get(&key) {
124 return Ok(x.clone());
125 }
126 let res = self.do_build_z3(build_ctx)?;
127 build_ctx.measure_cache.insert(key, res.clone());
128 Ok(res)
129 }
130
131 fn do_build_z3<'ctx>(self, build_ctx: &mut Z3BuildContext<'ctx>) -> Result<Real<'ctx>> {
132 use MeasureVariant as V;
133 let z3_ctx = build_ctx.z3_ctx;
134 Ok(match *self.variant {
135 V::Unbound => Real::fresh_const(z3_ctx, "measure_"),
136 V::Const(num, den) => Real::from_real(z3_ctx, num, den),
137 V::Add(left, right) => left.build_z3(build_ctx)?.add(right.build_z3(build_ctx)?),
138 V::Sub(left, right) => left.build_z3(build_ctx)?.sub(right.build_z3(build_ctx)?),
139 V::Mul(left, right) => left.build_z3(build_ctx)?.mul(right.build_z3(build_ctx)?),
140 V::Div(left, right) => left.build_z3(build_ctx)?.div(right.build_z3(build_ctx)?),
141 V::Select(condition, left, right) => condition
142 .build_z3(build_ctx)?
143 .ite(&left.build_z3(build_ctx)?, &right.build_z3(build_ctx)?),
144 })
145 }
146
147 pub fn prop_eq(self, that: Self) -> Prop<'a> {
148 Prop {
149 ctx: self.ctx,
150 variant: self.ctx.alloc.alloc(PropVariant::Eq(self, that)),
151 weight: 10,
152 }
153 }
154
155 pub fn prop_lt(self, that: Self) -> Prop<'a> {
156 Prop {
157 ctx: self.ctx,
158 variant: self.ctx.alloc.alloc(PropVariant::Lt(self, that)),
159 weight: 10,
160 }
161 }
162
163 pub fn prop_le(self, that: Self) -> Prop<'a> {
164 Prop {
165 ctx: self.ctx,
166 variant: self.ctx.alloc.alloc(PropVariant::Le(self, that)),
167 weight: 10,
168 }
169 }
170
171 pub fn prop_gt(self, that: Self) -> Prop<'a> {
172 Prop {
173 ctx: self.ctx,
174 variant: self.ctx.alloc.alloc(PropVariant::Gt(self, that)),
175 weight: 10,
176 }
177 }
178
179 pub fn prop_ge(self, that: Self) -> Prop<'a> {
180 Prop {
181 ctx: self.ctx,
182 variant: self.ctx.alloc.alloc(PropVariant::Ge(self, that)),
183 weight: 10,
184 }
185 }
186
187 pub fn min(self, that: Self) -> Measure<'a> {
188 self.prop_lt(that).select(self, that)
189 }
190
191 pub fn max(self, that: Self) -> Measure<'a> {
192 self.prop_gt(that).select(self, that)
193 }
194}
195
196impl<'a> Display for Measure<'a> {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 match self.variant {
199 MeasureVariant::Unbound => write!(f, "<{:p}>", self.variant),
200 MeasureVariant::Const(num, den) => write!(f, "{}", *num as f64 / *den as f64),
201 MeasureVariant::Add(l, r)
202 if r.variant as *const _ == &SMALL_MEASURE_CONSTS.0[0] as *const _ =>
203 {
204 write!(f, "{}", l)
205 }
206 MeasureVariant::Add(l, r) => write!(f, "({} + {})", l, r),
207 MeasureVariant::Sub(l, r)
208 if r.variant as *const _ == &SMALL_MEASURE_CONSTS.0[0] as *const _ =>
209 {
210 write!(f, "{}", l)
211 }
212 MeasureVariant::Sub(l, r) => write!(f, "({} - {})", l, r),
213 MeasureVariant::Mul(l, r) => write!(f, "({} * {})", l, r),
214 MeasureVariant::Div(l, r) => write!(f, "({} / {})", l, r),
215 MeasureVariant::Select(cond, l, r) => match cond.variant {
216 PropVariant::Lt(cond_l, cond_r)
217 if cond_l.variant as *const _ == l.variant as *const _
218 && cond_r.variant as *const _ == r.variant as *const _ =>
219 {
220 write!(f, "(min {} {})", l, r)
221 }
222 PropVariant::Gt(cond_l, cond_r)
223 if cond_l.variant as *const _ == l.variant as *const _
224 && cond_r.variant as *const _ == r.variant as *const _ =>
225 {
226 write!(f, "(max {} {})", l, r)
227 }
228 _ => write!(f, "(select ({}) {} {})", cond, l, r),
229 },
230 }
231 }
232}
233
234impl<'a> Add for Measure<'a> {
235 type Output = Self;
236
237 fn add(self, other: Self) -> Self {
238 Self {
239 ctx: self.ctx,
240 variant: self.ctx.alloc.alloc(MeasureVariant::Add(self, other)),
241 }
242 }
243}
244
245impl<'a> Add<f64> for Measure<'a> {
246 type Output = Self;
247
248 fn add(self, other: f64) -> Self {
249 self + Measure::new_const(self.ctx, other).unwrap()
250 }
251}
252
253impl<'a> Sub for Measure<'a> {
254 type Output = Self;
255
256 fn sub(self, other: Self) -> Self {
257 Self {
258 ctx: self.ctx,
259 variant: self.ctx.alloc.alloc(MeasureVariant::Sub(self, other)),
260 }
261 }
262}
263
264impl<'a> Sub<f64> for Measure<'a> {
265 type Output = Self;
266
267 fn sub(self, other: f64) -> Self {
268 self - Measure::new_const(self.ctx, other).unwrap()
269 }
270}
271
272impl<'a> Mul for Measure<'a> {
273 type Output = Self;
274
275 fn mul(self, other: Self) -> Self {
276 Self {
277 ctx: self.ctx,
278 variant: self.ctx.alloc.alloc(MeasureVariant::Mul(self, other)),
279 }
280 }
281}
282
283impl<'a> Mul<f64> for Measure<'a> {
284 type Output = Self;
285
286 fn mul(self, other: f64) -> Self {
287 self * Measure::new_const(self.ctx, other).unwrap()
288 }
289}
290
291impl<'a> Div for Measure<'a> {
292 type Output = Self;
293
294 fn div(self, other: Self) -> Self {
295 Self {
296 ctx: self.ctx,
297 variant: self.ctx.alloc.alloc(MeasureVariant::Div(self, other)),
298 }
299 }
300}
301
302impl<'a> Div<f64> for Measure<'a> {
303 type Output = Self;
304
305 fn div(self, other: f64) -> Self {
306 self / Measure::new_const(self.ctx, other).unwrap()
307 }
308}