1use std::{collections::HashMap, fmt, marker::PhantomData};
4
5use crate::{
6 arith::Substitutions,
7 error::{ErrorKind, OpErrors},
8 visit::{self, Visit},
9 Function, Object, PrimitiveType, Slice, Tuple, Type, TypeVar,
10};
11
/// Constraint that can be placed on a type (e.g., [`Linearity`], displayed as `Lin`, or
/// [`Ops`]).
///
/// The [`Display`](fmt::Display) output of a constraint doubles as its name:
/// [`ConstraintSet`] keys its entries by `to_string()`, so two constraints with the same
/// display output are treated as the same constraint within a set.
pub trait Constraint<Prim: PrimitiveType>: fmt::Display + Send + Sync + 'static {
    /// Returns a visitor that applies this constraint to the type it visits.
    ///
    /// The visitor may record new constraints in `substitutions`; the implementations in this
    /// module report constraint violations by pushing them into `errors`.
    fn visitor<'r>(
        &self,
        substitutions: &'r mut Substitutions<Prim>,
        errors: OpErrors<'r, Prim>,
    ) -> Box<dyn Visit<Prim> + 'r>;

    /// Clones this constraint into a `Box`.
    ///
    /// This method backs the `Clone` implementation for `Box<dyn Constraint<Prim>>`.
    fn clone_boxed(&self) -> Box<dyn Constraint<Prim>>;
}
53
54impl<Prim: PrimitiveType> fmt::Debug for dyn Constraint<Prim> {
55 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
56 formatter
57 .debug_tuple("dyn Constraint")
58 .field(&self.to_string())
59 .finish()
60 }
61}
62
63impl<Prim: PrimitiveType> Clone for Box<dyn Constraint<Prim>> {
64 fn clone(&self) -> Self {
65 self.clone_boxed()
66 }
67}
68
/// Marker trait for constraints that are flagged as object-safe.
///
/// Only constraints implementing this trait can be added via
/// [`ConstraintSet::insert_object_safe()`], which stores them with the object-safety flag
/// set to `true` (the flag is later reported by `ConstraintSet::get_by_name`).
pub trait ObjectSafeConstraint<Prim: PrimitiveType>: Constraint<Prim> {}
78
/// Helper for implementing [`Constraint`]s that apply structurally to types.
///
/// The visitor created by [`Self::visitor()`]:
///
/// - checks every primitive type with `predicate`, reporting a failed constraint if it
///   returns `false`;
/// - recurses into tuple elements and object fields, tracking their locations in errors;
/// - reports a failed constraint for any function type;
/// - for `dyn` types, reports a failed constraint unless their constraint set contains
///   `constraint`;
/// - for free type variables, records `constraint` in the substitutions and then visits the
///   variable's resolved type if it resolves to something other than a type variable.
///
/// If [`Self::deny_dyn_slices()`] was called, the length of a tuple's middle slice (if any)
/// is additionally passed to `Substitutions::apply_static_len`, and a resulting error is
/// reported.
#[derive(Debug)]
pub struct StructConstraint<Prim, C, F> {
    // Constraint being applied; cloned into error reports and recorded for type variables.
    constraint: C,
    // Predicate that every primitive type encountered must satisfy.
    predicate: F,
    // Whether tuple middle-slice lengths are passed to `apply_static_len`.
    deny_dyn_slices: bool,
    _prim: PhantomData<Prim>,
}
140
141impl<Prim, C, F> StructConstraint<Prim, C, F>
142where
143 Prim: PrimitiveType,
144 C: Constraint<Prim> + Clone,
145 F: Fn(&Prim) -> bool + 'static,
146{
147 pub fn new(constraint: C, predicate: F) -> Self {
150 Self {
151 constraint,
152 predicate,
153 deny_dyn_slices: false,
154 _prim: PhantomData,
155 }
156 }
157
158 pub fn deny_dyn_slices(mut self) -> Self {
160 self.deny_dyn_slices = true;
161 self
162 }
163
164 pub fn visitor<'r>(
166 self,
167 substitutions: &'r mut Substitutions<Prim>,
168 errors: OpErrors<'r, Prim>,
169 ) -> Box<dyn Visit<Prim> + 'r> {
170 Box::new(StructConstraintVisitor {
171 inner: self,
172 substitutions,
173 errors,
174 })
175 }
176}
177
/// [`Visit`]or produced by [`StructConstraint::visitor()`].
#[derive(Debug)]
struct StructConstraintVisitor<'r, Prim: PrimitiveType, C, F> {
    // Constraint configuration being applied.
    inner: StructConstraint<Prim, C, F>,
    // Substitutions updated when the constraint reaches a type variable.
    substitutions: &'r mut Substitutions<Prim>,
    // Sink for constraint failures; locations are pushed / popped around tuple elements
    // and object fields.
    errors: OpErrors<'r, Prim>,
}
184
185impl<'r, Prim, C, F> Visit<Prim> for StructConstraintVisitor<'r, Prim, C, F>
186where
187 Prim: PrimitiveType,
188 C: Constraint<Prim> + Clone,
189 F: Fn(&Prim) -> bool + 'static,
190{
191 fn visit_type(&mut self, ty: &Type<Prim>) {
192 match ty {
193 Type::Dyn(constraints) => {
194 if !constraints.inner.simple.contains(&self.inner.constraint) {
195 self.errors.push(ErrorKind::failed_constraint(
196 ty.clone(),
197 self.inner.constraint.clone(),
198 ));
199 }
200 }
201 _ => visit::visit_type(self, ty),
202 }
203 }
204
205 fn visit_var(&mut self, var: TypeVar) {
206 debug_assert!(var.is_free());
207 self.substitutions.insert_constraint(
208 var.index(),
209 &self.inner.constraint,
210 self.errors.by_ref(),
211 );
212
213 let resolved = self.substitutions.fast_resolve(&Type::Var(var)).clone();
214 if let Type::Var(_) = resolved {
215 } else {
217 visit::visit_type(self, &resolved);
218 }
219 }
220
221 fn visit_primitive(&mut self, primitive: &Prim) {
222 if !(self.inner.predicate)(primitive) {
223 self.errors.push(ErrorKind::failed_constraint(
224 Type::Prim(primitive.clone()),
225 self.inner.constraint.clone(),
226 ));
227 }
228 }
229
230 fn visit_tuple(&mut self, tuple: &Tuple<Prim>) {
231 if self.inner.deny_dyn_slices {
232 let middle_len = tuple.parts().1.map(Slice::len);
233 if let Some(middle_len) = middle_len {
234 if let Err(err) = self.substitutions.apply_static_len(middle_len) {
235 self.errors.push(err);
236 }
237 }
238 }
239
240 for (i, element) in tuple.element_types() {
241 self.errors.push_location(i);
242 self.visit_type(element);
243 self.errors.pop_location();
244 }
245 }
246
247 fn visit_object(&mut self, obj: &Object<Prim>) {
248 for (name, element) in obj.iter() {
249 self.errors.push_location(name);
250 self.visit_type(element);
251 self.errors.pop_location();
252 }
253 }
254
255 fn visit_function(&mut self, function: &Function<Prim>) {
256 self.errors.push(ErrorKind::failed_constraint(
257 function.clone().into(),
258 self.inner.constraint.clone(),
259 ));
260 }
261}
262
/// Linearity constraint, displayed as `Lin`.
///
/// Its visitor (a [`StructConstraint`]) requires every primitive type encountered to satisfy
/// [`LinearType::is_linear()`], recurses into tuples and objects, and fails on function types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Linearity;
270
271impl fmt::Display for Linearity {
272 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
273 formatter.write_str("Lin")
274 }
275}
276
277impl<Prim: LinearType> Constraint<Prim> for Linearity {
278 fn visitor<'r>(
279 &self,
280 substitutions: &'r mut Substitutions<Prim>,
281 errors: OpErrors<'r, Prim>,
282 ) -> Box<dyn Visit<Prim> + 'r> {
283 StructConstraint::new(*self, LinearType::is_linear).visitor(substitutions, errors)
284 }
285
286 fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
287 Box::new(*self)
288 }
289}
290
/// Allows `Lin` to be inserted via [`ConstraintSet::insert_object_safe()`].
impl<Prim: LinearType> ObjectSafeConstraint<Prim> for Linearity {}
292
/// Primitive type that can report whether it is linear.
///
/// [`Linearity`] and [`Ops`] use [`Self::is_linear()`] as the predicate that every primitive
/// type they encounter must satisfy.
pub trait LinearType: PrimitiveType {
    /// Returns `true` if this primitive type is linear. A `false` result makes the `Lin` and
    /// `Ops` constraints report a failure for this type.
    fn is_linear(&self) -> bool;
}
299
/// Constraint displayed as `Ops`.
///
/// Like [`Linearity`], it requires every primitive type to satisfy
/// [`LinearType::is_linear()`] and fails on function types. Additionally, the length of a
/// tuple's middle slice (if any) is passed to `Substitutions::apply_static_len`, with any
/// error reported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Ops;
306
307impl fmt::Display for Ops {
308 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
309 formatter.write_str("Ops")
310 }
311}
312
313impl<Prim: LinearType> Constraint<Prim> for Ops {
314 fn visitor<'r>(
315 &self,
316 substitutions: &'r mut Substitutions<Prim>,
317 errors: OpErrors<'r, Prim>,
318 ) -> Box<dyn Visit<Prim> + 'r> {
319 StructConstraint::new(*self, LinearType::is_linear)
320 .deny_dyn_slices()
321 .visitor(substitutions, errors)
322 }
323
324 fn clone_boxed(&self) -> Box<dyn Constraint<Prim>> {
325 Box::new(*self)
326 }
327}
328
/// Set of [`Constraint`]s keyed by their `Display` names.
#[derive(Debug, Clone)]
pub struct ConstraintSet<Prim: PrimitiveType> {
    // Maps a constraint name (`to_string()` output) to the boxed constraint and a flag that is
    // `true` only if the constraint was inserted via `insert_object_safe`.
    inner: HashMap<String, (Box<dyn Constraint<Prim>>, bool)>,
}
337
338impl<Prim: PrimitiveType> Default for ConstraintSet<Prim> {
339 fn default() -> Self {
340 Self::new()
341 }
342}
343
344impl<Prim: PrimitiveType> PartialEq for ConstraintSet<Prim> {
345 fn eq(&self, other: &Self) -> bool {
346 if self.inner.len() == other.inner.len() {
347 self.inner.keys().all(|key| other.inner.contains_key(key))
348 } else {
349 false
350 }
351 }
352}
353
354impl<Prim: PrimitiveType> fmt::Display for ConstraintSet<Prim> {
355 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
356 let len = self.inner.len();
357 for (i, (constraint, _)) in self.inner.values().enumerate() {
358 fmt::Display::fmt(constraint, formatter)?;
359 if i + 1 < len {
360 formatter.write_str(" + ")?;
361 }
362 }
363 Ok(())
364 }
365}
366
367impl<Prim: PrimitiveType> ConstraintSet<Prim> {
368 pub fn new() -> Self {
370 Self {
371 inner: HashMap::new(),
372 }
373 }
374
375 pub fn just(constraint: impl Constraint<Prim>) -> Self {
377 let mut this = Self::new();
378 this.insert(constraint);
379 this
380 }
381
382 pub fn is_empty(&self) -> bool {
384 self.inner.is_empty()
385 }
386
387 fn contains(&self, constraint: &impl Constraint<Prim>) -> bool {
388 self.inner.contains_key(&constraint.to_string())
389 }
390
391 pub fn insert(&mut self, constraint: impl Constraint<Prim>) {
393 self.inner
394 .insert(constraint.to_string(), (Box::new(constraint), false));
395 }
396
397 pub fn insert_object_safe(&mut self, constraint: impl ObjectSafeConstraint<Prim>) {
399 self.inner
400 .insert(constraint.to_string(), (Box::new(constraint), true));
401 }
402
403 pub(crate) fn insert_boxed(&mut self, constraint: Box<dyn Constraint<Prim>>) {
405 self.inner
406 .insert(constraint.to_string(), (constraint, false));
407 }
408
409 pub(crate) fn get_by_name(&self, name: &str) -> Option<(&dyn Constraint<Prim>, bool)> {
411 self.inner
412 .get(name)
413 .map(|(constraint, is_object_safe)| (constraint.as_ref(), *is_object_safe))
414 }
415
416 pub(crate) fn apply_all(
418 &self,
419 ty: &Type<Prim>,
420 substitutions: &mut Substitutions<Prim>,
421 mut errors: OpErrors<'_, Prim>,
422 ) {
423 for (constraint, _) in self.inner.values() {
424 constraint
425 .visitor(substitutions, errors.by_ref())
426 .visit_type(ty);
427 }
428 }
429
430 pub(crate) fn apply_all_to_object(
432 &self,
433 object: &Object<Prim>,
434 substitutions: &mut Substitutions<Prim>,
435 mut errors: OpErrors<'_, Prim>,
436 ) {
437 for (constraint, _) in self.inner.values() {
438 constraint
439 .visitor(substitutions, errors.by_ref())
440 .visit_object(object);
441 }
442 }
443}
444
/// Full set of constraints on a type: simple (named) constraints plus an optional object
/// constraint.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct CompleteConstraints<Prim: PrimitiveType> {
    /// Simple constraints such as `Lin` or `Ops`.
    pub simple: ConstraintSet<Prim>,
    /// Object constraint, if any.
    pub object: Option<Object<Prim>>,
}
452
453impl<Prim: PrimitiveType> Default for CompleteConstraints<Prim> {
454 fn default() -> Self {
455 Self {
456 simple: ConstraintSet::new(),
457 object: None,
458 }
459 }
460}
461
462impl<Prim: PrimitiveType> From<ConstraintSet<Prim>> for CompleteConstraints<Prim> {
463 fn from(constraints: ConstraintSet<Prim>) -> Self {
464 Self {
465 simple: constraints,
466 object: None,
467 }
468 }
469}
470
471impl<Prim: PrimitiveType> From<Object<Prim>> for CompleteConstraints<Prim> {
472 fn from(object: Object<Prim>) -> Self {
473 Self {
474 simple: ConstraintSet::default(),
475 object: Some(object),
476 }
477 }
478}
479
480impl<Prim: PrimitiveType> fmt::Display for CompleteConstraints<Prim> {
481 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
482 match (&self.object, self.simple.is_empty()) {
483 (Some(object), false) => write!(formatter, "{} + {}", object, self.simple),
484 (Some(object), true) => fmt::Display::fmt(object, formatter),
485 (None, _) => fmt::Display::fmt(&self.simple, formatter),
486 }
487 }
488}
489
490impl<Prim: PrimitiveType> CompleteConstraints<Prim> {
491 pub fn is_empty(&self) -> bool {
493 self.object.is_none() && self.simple.is_empty()
494 }
495
496 pub fn insert(
498 &mut self,
499 constraint: impl Constraint<Prim>,
500 substitutions: &mut Substitutions<Prim>,
501 errors: OpErrors<'_, Prim>,
502 ) {
503 self.simple.insert(constraint);
504 self.check_object_consistency(substitutions, errors);
505 }
506
507 pub fn apply_all(
509 &self,
510 ty: &Type<Prim>,
511 substitutions: &mut Substitutions<Prim>,
512 mut errors: OpErrors<'_, Prim>,
513 ) {
514 self.simple.apply_all(ty, substitutions, errors.by_ref());
515 if let Some(lhs) = &self.object {
516 lhs.apply_as_constraint(ty, substitutions, errors);
517 }
518 }
519
520 pub fn map_object(self, map: impl FnOnce(&mut Object<Prim>)) -> Self {
522 Self {
523 simple: self.simple,
524 object: self.object.map(|mut object| {
525 map(&mut object);
526 object
527 }),
528 }
529 }
530
531 pub fn insert_obj_constraint(
533 &mut self,
534 object: Object<Prim>,
535 substitutions: &mut Substitutions<Prim>,
536 mut errors: OpErrors<'_, Prim>,
537 ) {
538 if let Some(existing_object) = &mut self.object {
539 existing_object.extend_from(object, substitutions, errors.by_ref());
540 } else {
541 self.object = Some(object);
542 }
543 self.check_object_consistency(substitutions, errors);
544 }
545
546 fn check_object_consistency(
547 &self,
548 substitutions: &mut Substitutions<Prim>,
549 errors: OpErrors<'_, Prim>,
550 ) {
551 if let Some(object) = &self.object {
552 self.simple
553 .apply_all_to_object(&object, substitutions, errors);
554 }
555 }
556}