1use num_traits::NumOps;
5
6use std::{fmt, str::FromStr};
7
8use crate::{
9 error::{ErrorKind, ErrorLocation, OpErrors},
10 PrimitiveType, Type,
11};
12use arithmetic_parser::{BinaryOp, UnaryOp};
13
14mod constraints;
15mod substitutions;
16
17pub(crate) use self::constraints::CompleteConstraints;
18pub use self::constraints::{
19 Constraint, ConstraintSet, LinearType, Linearity, ObjectSafeConstraint, Ops, StructConstraint,
20};
21pub use self::substitutions::Substitutions;
22
23pub trait MapPrimitiveType<Val> {
28 type Prim: PrimitiveType;
30
31 fn type_of_literal(&self, lit: &Val) -> Self::Prim;
33}
34
35pub trait TypeArithmetic<Prim: PrimitiveType> {
42 fn process_unary_op(
44 &self,
45 substitutions: &mut Substitutions<Prim>,
46 context: &UnaryOpContext<Prim>,
47 errors: OpErrors<'_, Prim>,
48 ) -> Type<Prim>;
49
50 fn process_binary_op(
52 &self,
53 substitutions: &mut Substitutions<Prim>,
54 context: &BinaryOpContext<Prim>,
55 errors: OpErrors<'_, Prim>,
56 ) -> Type<Prim>;
57}
58
59#[derive(Debug, Clone)]
63pub struct UnaryOpContext<Prim: PrimitiveType> {
64 pub op: UnaryOp,
66 pub arg: Type<Prim>,
68}
69
70#[derive(Debug, Clone)]
74pub struct BinaryOpContext<Prim: PrimitiveType> {
75 pub op: BinaryOp,
77 pub lhs: Type<Prim>,
79 pub rhs: Type<Prim>,
81}
82
83pub trait WithBoolean: PrimitiveType {
85 const BOOL: Self;
87}
88
/// Arithmetic that supports only Boolean operations (`!`, `&&`, `||`) and equality (`==`, `!=`).
/// Every other operation is reported as unsupported.
#[derive(Clone, Copy, Debug, Default)]
pub struct BoolArithmetic;
93
94impl<Prim: WithBoolean> TypeArithmetic<Prim> for BoolArithmetic {
95 fn process_unary_op<'a>(
100 &self,
101 substitutions: &mut Substitutions<Prim>,
102 context: &UnaryOpContext<Prim>,
103 mut errors: OpErrors<'_, Prim>,
104 ) -> Type<Prim> {
105 let op = context.op;
106 if op == UnaryOp::Not {
107 substitutions.unify(&Type::BOOL, &context.arg, errors);
108 Type::BOOL
109 } else {
110 let err = ErrorKind::unsupported(op);
111 errors.push(err);
112 substitutions.new_type_var()
113 }
114 }
115
116 fn process_binary_op(
123 &self,
124 substitutions: &mut Substitutions<Prim>,
125 context: &BinaryOpContext<Prim>,
126 mut errors: OpErrors<'_, Prim>,
127 ) -> Type<Prim> {
128 match context.op {
129 BinaryOp::Eq | BinaryOp::NotEq => {
130 substitutions.unify(&context.lhs, &context.rhs, errors);
131 Type::BOOL
132 }
133
134 BinaryOp::And | BinaryOp::Or => {
135 substitutions.unify(
136 &Type::BOOL,
137 &context.lhs,
138 errors.with_location(ErrorLocation::Lhs),
139 );
140 substitutions.unify(
141 &Type::BOOL,
142 &context.rhs,
143 errors.with_location(ErrorLocation::Rhs),
144 );
145 Type::BOOL
146 }
147
148 _ => {
149 errors.push(ErrorKind::unsupported(context.op));
150 substitutions.new_type_var()
151 }
152 }
153 }
154}
155
156#[derive(Debug)]
158pub struct OpConstraintSettings<'a, Prim: PrimitiveType> {
159 pub lin: &'a dyn Constraint<Prim>,
161 pub ops: &'a dyn Constraint<Prim>,
163}
164
165impl<Prim: PrimitiveType> Clone for OpConstraintSettings<'_, Prim> {
166 fn clone(&self) -> Self {
167 Self {
168 lin: self.lin,
169 ops: self.ops,
170 }
171 }
172}
173
174impl<Prim: PrimitiveType> Copy for OpConstraintSettings<'_, Prim> {}
175
/// Primitive types of the numeric type system: a single number type plus Booleans.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Num {
    /// Numeric type (the only linear primitive type).
    Num,
    /// Boolean type.
    Bool,
}
184
185impl fmt::Display for Num {
186 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
187 formatter.write_str(match self {
188 Self::Num => "Num",
189 Self::Bool => "Bool",
190 })
191 }
192}
193
194impl FromStr for Num {
195 type Err = anyhow::Error;
196
197 fn from_str(s: &str) -> Result<Self, Self::Err> {
198 match s {
199 "Num" => Ok(Self::Num),
200 "Bool" => Ok(Self::Bool),
201 _ => Err(anyhow::anyhow!("Expected `Num` or `Bool`")),
202 }
203 }
204}
205
206impl PrimitiveType for Num {
207 fn well_known_constraints() -> ConstraintSet<Self> {
208 let mut constraints = ConstraintSet::default();
209 constraints.insert_object_safe(Linearity);
210 constraints.insert(Ops);
211 constraints
212 }
213}
214
215impl WithBoolean for Num {
216 const BOOL: Self = Self::Bool;
217}
218
219impl LinearType for Num {
221 fn is_linear(&self) -> bool {
222 matches!(self, Self::Num)
223 }
224}
225
/// Arithmetic on the [`Num`] primitive type system.
///
/// Comparison operators (`<`, `<=`, `>`, `>=`) are allowed only if `comparisons_enabled`
/// is set. Use [`NumArithmetic::with_comparisons()`] or
/// [`NumArithmetic::without_comparisons()`] to construct an instance.
#[derive(Clone, Debug)]
pub struct NumArithmetic {
    /// Whether comparison operators are accepted.
    comparisons_enabled: bool,
}
249
250impl NumArithmetic {
251 pub const fn without_comparisons() -> Self {
253 Self {
254 comparisons_enabled: false,
255 }
256 }
257
258 pub const fn with_comparisons() -> Self {
260 Self {
261 comparisons_enabled: true,
262 }
263 }
264
265 pub fn unify_binary_op<Prim: PrimitiveType>(
274 substitutions: &mut Substitutions<Prim>,
275 context: &BinaryOpContext<Prim>,
276 mut errors: OpErrors<'_, Prim>,
277 settings: OpConstraintSettings<'_, Prim>,
278 ) -> Type<Prim> {
279 let lhs_ty = &context.lhs;
280 let rhs_ty = &context.rhs;
281 let resolved_lhs_ty = substitutions.fast_resolve(lhs_ty);
282 let resolved_rhs_ty = substitutions.fast_resolve(rhs_ty);
283
284 match (
285 resolved_lhs_ty.is_primitive(),
286 resolved_rhs_ty.is_primitive(),
287 ) {
288 (Some(true), Some(false)) => {
289 let resolved_rhs_ty = resolved_rhs_ty.clone();
290 settings
291 .lin
292 .visitor(substitutions, errors.with_location(ErrorLocation::Lhs))
293 .visit_type(lhs_ty);
294 settings
295 .lin
296 .visitor(substitutions, errors.with_location(ErrorLocation::Rhs))
297 .visit_type(rhs_ty);
298 resolved_rhs_ty
299 }
300 (Some(false), Some(true)) => {
301 let resolved_lhs_ty = resolved_lhs_ty.clone();
302 settings
303 .lin
304 .visitor(substitutions, errors.with_location(ErrorLocation::Lhs))
305 .visit_type(lhs_ty);
306 settings
307 .lin
308 .visitor(substitutions, errors.with_location(ErrorLocation::Rhs))
309 .visit_type(rhs_ty);
310 resolved_lhs_ty
311 }
312 _ => {
313 let lhs_is_valid = errors.with_location(ErrorLocation::Lhs).check(|errors| {
314 settings
315 .ops
316 .visitor(substitutions, errors)
317 .visit_type(lhs_ty);
318 });
319 let rhs_is_valid = errors.with_location(ErrorLocation::Rhs).check(|errors| {
320 settings
321 .ops
322 .visitor(substitutions, errors)
323 .visit_type(rhs_ty);
324 });
325
326 if lhs_is_valid && rhs_is_valid {
327 substitutions.unify(lhs_ty, rhs_ty, errors);
328 }
329 if lhs_is_valid {
330 lhs_ty.clone()
331 } else {
332 rhs_ty.clone()
333 }
334 }
335 }
336 }
337
338 pub fn process_unary_op<Prim: WithBoolean>(
343 substitutions: &mut Substitutions<Prim>,
344 context: &UnaryOpContext<Prim>,
345 mut errors: OpErrors<'_, Prim>,
346 constraints: &impl Constraint<Prim>,
347 ) -> Type<Prim> {
348 match context.op {
349 UnaryOp::Not => BoolArithmetic.process_unary_op(substitutions, context, errors),
350 UnaryOp::Neg => {
351 constraints
352 .visitor(substitutions, errors)
353 .visit_type(&context.arg);
354 context.arg.clone()
355 }
356 _ => {
357 errors.push(ErrorKind::unsupported(context.op));
358 substitutions.new_type_var()
359 }
360 }
361 }
362
363 pub fn process_binary_op<Prim: WithBoolean>(
374 substitutions: &mut Substitutions<Prim>,
375 context: &BinaryOpContext<Prim>,
376 mut errors: OpErrors<'_, Prim>,
377 comparable_type: Option<Prim>,
378 settings: OpConstraintSettings<'_, Prim>,
379 ) -> Type<Prim> {
380 match context.op {
381 BinaryOp::And | BinaryOp::Or | BinaryOp::Eq | BinaryOp::NotEq => {
382 BoolArithmetic.process_binary_op(substitutions, context, errors)
383 }
384
385 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Power => {
386 Self::unify_binary_op(substitutions, context, errors, settings)
387 }
388
389 BinaryOp::Ge | BinaryOp::Le | BinaryOp::Lt | BinaryOp::Gt => {
390 if let Some(ty) = comparable_type {
391 let ty = Type::Prim(ty);
392 substitutions.unify(
393 &ty,
394 &context.lhs,
395 errors.with_location(ErrorLocation::Lhs),
396 );
397 substitutions.unify(
398 &ty,
399 &context.rhs,
400 errors.with_location(ErrorLocation::Rhs),
401 );
402 } else {
403 let err = ErrorKind::unsupported(context.op);
404 errors.push(err);
405 }
406 Type::BOOL
407 }
408
409 _ => {
410 errors.push(ErrorKind::unsupported(context.op));
411 substitutions.new_type_var()
412 }
413 }
414 }
415}
416
417impl<Val> MapPrimitiveType<Val> for NumArithmetic
418where
419 Val: Clone + NumOps + PartialEq,
420{
421 type Prim = Num;
422
423 fn type_of_literal(&self, _: &Val) -> Self::Prim {
424 Num::Num
425 }
426}
427
428impl TypeArithmetic<Num> for NumArithmetic {
429 fn process_unary_op<'a>(
430 &self,
431 substitutions: &mut Substitutions<Num>,
432 context: &UnaryOpContext<Num>,
433 errors: OpErrors<'_, Num>,
434 ) -> Type<Num> {
435 Self::process_unary_op(substitutions, context, errors, &Linearity)
436 }
437
438 fn process_binary_op<'a>(
439 &self,
440 substitutions: &mut Substitutions<Num>,
441 context: &BinaryOpContext<Num>,
442 errors: OpErrors<'_, Num>,
443 ) -> Type<Num> {
444 const OP_SETTINGS: OpConstraintSettings<'static, Num> = OpConstraintSettings {
445 lin: &Linearity,
446 ops: &Ops,
447 };
448
449 let comparable_type = if self.comparisons_enabled {
450 Some(Num::Num)
451 } else {
452 None
453 };
454
455 Self::process_binary_op(substitutions, context, errors, comparable_type, OP_SETTINGS)
456 }
457}