hugr_core/types/
type_param.rs

1//! Type Parameters
2//!
3//! Parameters for [`TypeDef`]s provided by extensions
4//!
5//! [`TypeDef`]: crate::extension::TypeDef
6
7use ordered_float::OrderedFloat;
8#[cfg(test)]
9use proptest_derive::Arbitrary;
10use smallvec::{SmallVec, smallvec};
11use std::iter::FusedIterator;
12use std::num::NonZeroU64;
13use std::sync::Arc;
14use thiserror::Error;
15use tracing::warn;
16
17use super::row_var::MaybeRV;
18use super::{
19    NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeBound, TypeTransformer,
20    check_typevar_decl,
21};
22use crate::extension::SignatureError;
23
24/// The upper non-inclusive bound of a [`TypeParam::BoundedNat`]
25// A None inner value implies the maximum bound: u64::MAX + 1 (all u64 values valid)
26#[derive(
27    Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize,
28)]
29#[display("{}", _0.map(|i|i.to_string()).unwrap_or("-".to_string()))]
30#[cfg_attr(test, derive(Arbitrary))]
31pub struct UpperBound(Option<NonZeroU64>);
32impl UpperBound {
33    fn valid_value(&self, val: u64) -> bool {
34        match (val, self.0) {
35            (0, _) | (_, None) => true,
36            (val, Some(inner)) if NonZeroU64::new(val).unwrap() < inner => true,
37            _ => false,
38        }
39    }
40    fn contains(&self, other: &UpperBound) -> bool {
41        match (self.0, other.0) {
42            (None, _) => true,
43            (Some(b1), Some(b2)) if b1 >= b2 => true,
44            _ => false,
45        }
46    }
47
48    /// Returns the value of the upper bound.
49    #[must_use]
50    pub fn value(&self) -> &Option<NonZeroU64> {
51        &self.0
52    }
53}
54
/// A [`Term`] that is a static argument to an operation or constructor.
///
/// This is an alias: it names the *argument* role of a [`Term`].
pub type TypeArg = Term;

/// A [`Term`] that is the static type of an operation or constructor parameter.
///
/// This is an alias: it names the *parameter (type)* role of a [`Term`].
pub type TypeParam = Term;
60
/// A term in the language of static parameters in HUGR.
///
/// Terms cover both static types (e.g. [`Term::StringType`], [`Term::ListType`])
/// and their instances (e.g. [`Term::String`], [`Term::List`]). The aliases
/// [`TypeParam`] and [`TypeArg`] name these two roles.
///
/// (De)serialization goes through `crate::types::serialize::TermSer`.
#[derive(
    Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize,
)]
#[non_exhaustive]
#[serde(
    from = "crate::types::serialize::TermSer",
    into = "crate::types::serialize::TermSer"
)]
pub enum Term {
    /// The type of runtime types.
    #[display("Type{}", match _0 {
        TypeBound::Linear => String::new(),
        _ => format!("[{_0}]")
    })]
    RuntimeType(TypeBound),
    /// The type of static data.
    StaticType,
    /// The type of static natural numbers up to a given (exclusive) bound.
    #[display("{}", match _0.value() {
        Some(v) => format!("BoundedNat[{v}]"),
        None => "Nat".to_string()
    })]
    BoundedNatType(UpperBound),
    /// The type of static strings. See [`Term::String`].
    StringType,
    /// The type of static byte strings. See [`Term::Bytes`].
    BytesType,
    /// The type of static floating point numbers. See [`Term::Float`].
    FloatType,
    /// The type of static lists of indeterminate size containing terms of the
    /// specified static type.
    #[display("ListType[{_0}]")]
    ListType(Box<Term>),
    /// The type of static tuples.
    ///
    /// The boxed term describes the list of item types.
    #[display("TupleType[{_0}]")]
    TupleType(Box<Term>),
    /// A runtime type as a term. Instance of [`Term::RuntimeType`].
    #[display("{_0}")]
    Runtime(Type),
    /// A 64bit unsigned integer literal. Instance of [`Term::BoundedNatType`].
    #[display("{_0}")]
    BoundedNat(u64),
    /// UTF-8 encoded string literal. Instance of [`Term::StringType`].
    #[display("\"{_0}\"")]
    String(String),
    /// Byte string literal. Instance of [`Term::BytesType`].
    ///
    /// Displayed as the placeholder `bytes`; the contents are not printed.
    #[display("bytes")]
    Bytes(Arc<[u8]>),
    /// A 64-bit floating point number. Instance of [`Term::FloatType`].
    #[display("{}", _0.into_inner())]
    Float(OrderedFloat<f64>),
    /// A list of static terms. Instance of [`Term::ListType`].
    #[display("[{}]", {
        use itertools::Itertools as _;
        _0.iter().map(|t|t.to_string()).join(",")
    })]
    List(Vec<Term>),
    /// Instance of [`Term::ListType`] defined by a sequence of concatenated lists of the same type.
    #[display("[{}]", {
        use itertools::Itertools as _;
        _0.iter().map(|t| format!("... {t}")).join(",")
    })]
    ListConcat(Vec<TypeArg>),
    /// Instance of [`Term::TupleType`] defined by a sequence of elements of varying type.
    #[display("({})", {
        use itertools::Itertools as _;
        _0.iter().map(std::string::ToString::to_string).join(",")
    })]
    Tuple(Vec<Term>),
    /// Instance of [`Term::TupleType`] defined by a sequence of concatenated tuples.
    #[display("({})", {
        use itertools::Itertools as _;
        _0.iter().map(|tuple| format!("... {tuple}")).join(",")
    })]
    TupleConcat(Vec<TypeArg>),
    /// Variable (used in type schemes or inside polymorphic functions) that is
    /// not a single runtime type.
    ///
    /// A variable declared as a [`Term::RuntimeType`] is represented as a
    /// [`Term::Runtime`] instead. Row variables (variables standing for a list
    /// of runtime types) *are* represented by this variant. See
    /// [`Term::new_var_use`].
    #[display("{_0}")]
    Variable(TermVar),

    /// The type of constants for a runtime type.
    ///
    /// A constant is a compile time description of how to produce a runtime value.
    /// The runtime value is constructed when the constant is loaded.
    ///
    /// Constants are distinct from the runtime values that they describe. In
    /// particular, as part of the term language, constants can be freely copied
    /// or destroyed even when they describe a non-linear runtime value.
    ConstType(Box<Type>),
}
153
154impl Term {
155    /// Creates a [`Term::BoundedNatType`] with the maximum bound (`u64::MAX` + 1).
156    #[must_use]
157    pub const fn max_nat_type() -> Self {
158        Self::BoundedNatType(UpperBound(None))
159    }
160
161    /// Creates a [`Term::BoundedNatType`] with the stated upper bound (non-exclusive).
162    #[must_use]
163    pub const fn bounded_nat_type(upper_bound: NonZeroU64) -> Self {
164        Self::BoundedNatType(UpperBound(Some(upper_bound)))
165    }
166
167    /// Creates a new [`Term::List`] given a sequence of its items.
168    pub fn new_list(items: impl IntoIterator<Item = Term>) -> Self {
169        Self::List(items.into_iter().collect())
170    }
171
172    /// Creates a new [`Term::ListType`] given the type of its elements.
173    pub fn new_list_type(elem: impl Into<Term>) -> Self {
174        Self::ListType(Box::new(elem.into()))
175    }
176
177    /// Creates a new [`Term::TupleType`] given the type of its elements.
178    pub fn new_tuple_type(item_types: impl Into<Term>) -> Self {
179        Self::TupleType(Box::new(item_types.into()))
180    }
181
182    /// Creates a new [`Term::ConstType`] from a runtime type.
183    pub fn new_const(ty: impl Into<Type>) -> Self {
184        Self::ConstType(Box::new(ty.into()))
185    }
186
187    /// Checks if this term is a supertype of another.
188    ///
189    /// The subtyping relation applies primarily to terms that represent static
190    /// types. For consistency the relation is extended to a partial order on
191    /// all terms; in particular it is reflexive so that every term (even if it
192    /// is not a static type) is considered a subtype of itself.
193    fn is_supertype(&self, other: &Term) -> bool {
194        match (self, other) {
195            (Term::RuntimeType(b1), Term::RuntimeType(b2)) => b1.contains(*b2),
196            (Term::BoundedNatType(b1), Term::BoundedNatType(b2)) => b1.contains(b2),
197            (Term::StringType, Term::StringType) => true,
198            (Term::StaticType, Term::StaticType) => true,
199            (Term::ListType(e1), Term::ListType(e2)) => e1.is_supertype(e2),
200            (Term::TupleType(es1), Term::TupleType(es2)) => es1.is_supertype(es2),
201            (Term::BytesType, Term::BytesType) => true,
202            (Term::FloatType, Term::FloatType) => true,
203            (Term::Runtime(t1), Term::Runtime(t2)) => t1 == t2,
204            (Term::BoundedNat(n1), Term::BoundedNat(n2)) => n1 == n2,
205            (Term::String(s1), Term::String(s2)) => s1 == s2,
206            (Term::Bytes(v1), Term::Bytes(v2)) => v1 == v2,
207            (Term::Float(f1), Term::Float(f2)) => f1 == f2,
208            (Term::Variable(v1), Term::Variable(v2)) => v1 == v2,
209            (Term::List(es1), Term::List(es2)) => {
210                es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2))
211            }
212            (Term::Tuple(es1), Term::Tuple(es2)) => {
213                es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.is_supertype(e2))
214            }
215            _ => false,
216        }
217    }
218}
219
220impl From<TypeBound> for Term {
221    fn from(bound: TypeBound) -> Self {
222        Self::RuntimeType(bound)
223    }
224}
225
226impl From<UpperBound> for Term {
227    fn from(bound: UpperBound) -> Self {
228        Self::BoundedNatType(bound)
229    }
230}
231
232impl<RV: MaybeRV> From<TypeBase<RV>> for Term {
233    fn from(value: TypeBase<RV>) -> Self {
234        match value.try_into_type() {
235            Ok(ty) => Term::Runtime(ty),
236            Err(RowVariable(idx, bound)) => Term::new_var_use(idx, TypeParam::new_list_type(bound)),
237        }
238    }
239}
240
241impl From<u64> for Term {
242    fn from(n: u64) -> Self {
243        Self::BoundedNat(n)
244    }
245}
246
247impl From<String> for Term {
248    fn from(arg: String) -> Self {
249        Term::String(arg)
250    }
251}
252
253impl From<&str> for Term {
254    fn from(arg: &str) -> Self {
255        Term::String(arg.to_string())
256    }
257}
258
259impl From<Vec<Term>> for Term {
260    fn from(elems: Vec<Term>) -> Self {
261        Self::new_list(elems)
262    }
263}
264
265impl<const N: usize> From<[Term; N]> for Term {
266    fn from(value: [Term; N]) -> Self {
267        Self::new_list(value)
268    }
269}
270
/// Variable in a [`Term`] that is not a single runtime type (i.e. not a
/// [`Type::new_var_use`]). It might, however, be a row variable
/// ([`Type::new_row_var_use`]).
///
/// Displayed as `#` followed by the variable's index.
#[derive(
    Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display,
)]
#[display("#{idx}")]
pub struct TermVar {
    // Index of the variable; `validate` checks it against the declarations
    // passed in (presumably an index into that slice — see `check_typevar_decl`).
    idx: usize,
    // The declaration the variable was created with (see `Term::new_var_use`);
    // `check_term_type` uses it as the variable's type.
    pub(in crate::types) cached_decl: Box<Term>,
}
281
282impl Term {
283    /// [`Type::UNIT`] as a [`Term::Runtime`]
284    pub const UNIT: Self = Self::Runtime(Type::UNIT);
285
286    /// Makes a `TypeArg` representing a use (occurrence) of the type variable
287    /// with the specified index.
288    /// `decl` must be exactly that with which the variable was declared.
289    #[must_use]
290    pub fn new_var_use(idx: usize, decl: Term) -> Self {
291        match decl {
292            // Note a TypeParam::List of TypeParam::Type *cannot* be represented
293            // as a TypeArg::Type because the latter stores a Type<false> i.e. only a single type,
294            // not a RowVariable.
295            Term::RuntimeType(b) => Type::new_var_use(idx, b).into(),
296            _ => Term::Variable(TermVar {
297                idx,
298                cached_decl: Box::new(decl),
299            }),
300        }
301    }
302
303    /// Creates a new string literal.
304    #[inline]
305    pub fn new_string(str: impl ToString) -> Self {
306        Self::String(str.to_string())
307    }
308
309    /// Creates a new concatenated list.
310    #[inline]
311    pub fn new_list_concat(lists: impl IntoIterator<Item = Self>) -> Self {
312        Self::ListConcat(lists.into_iter().collect())
313    }
314
315    /// Creates a new tuple from its items.
316    #[inline]
317    pub fn new_tuple(items: impl IntoIterator<Item = Self>) -> Self {
318        Self::Tuple(items.into_iter().collect())
319    }
320
321    /// Creates a new concatenated tuple.
322    #[inline]
323    pub fn new_tuple_concat(tuples: impl IntoIterator<Item = Self>) -> Self {
324        Self::TupleConcat(tuples.into_iter().collect())
325    }
326
327    /// Returns an integer if the [`Term`] is a natural number literal.
328    #[must_use]
329    pub fn as_nat(&self) -> Option<u64> {
330        match self {
331            TypeArg::BoundedNat(n) => Some(*n),
332            _ => None,
333        }
334    }
335
336    /// Returns a [`Type`] if the [`Term`] is a runtime type.
337    #[must_use]
338    pub fn as_runtime(&self) -> Option<TypeBase<NoRV>> {
339        match self {
340            TypeArg::Runtime(ty) => Some(ty.clone()),
341            _ => None,
342        }
343    }
344
345    /// Returns a string if the [`Term`] is a string literal.
346    #[must_use]
347    pub fn as_string(&self) -> Option<String> {
348        match self {
349            TypeArg::String(arg) => Some(arg.clone()),
350            _ => None,
351        }
352    }
353
354    /// Much as [`Type::validate`], also checks that the type of any [`TypeArg::Opaque`]
355    /// is valid and closed.
356    pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> {
357        match self {
358            Term::Runtime(ty) => ty.validate(var_decls),
359            Term::List(elems) => {
360                // TODO: Full validation would check that the type of the elements agrees
361                elems.iter().try_for_each(|a| a.validate(var_decls))
362            }
363            Term::Tuple(elems) => elems.iter().try_for_each(|a| a.validate(var_decls)),
364            Term::BoundedNat(_) | Term::String { .. } | Term::Float(_) | Term::Bytes(_) => Ok(()),
365            TypeArg::ListConcat(lists) => {
366                // TODO: Full validation would check that each of the lists is indeed a
367                // list or list variable of the correct types.
368                lists.iter().try_for_each(|a| a.validate(var_decls))
369            }
370            TypeArg::TupleConcat(tuples) => tuples.iter().try_for_each(|a| a.validate(var_decls)),
371            Term::Variable(TermVar { idx, cached_decl }) => {
372                assert!(
373                    !matches!(&**cached_decl, TypeParam::RuntimeType { .. }),
374                    "Malformed TypeArg::Variable {cached_decl} - should be inconstructible"
375                );
376
377                check_typevar_decl(var_decls, *idx, cached_decl)
378            }
379            Term::RuntimeType { .. } => Ok(()),
380            Term::BoundedNatType { .. } => Ok(()),
381            Term::StringType => Ok(()),
382            Term::BytesType => Ok(()),
383            Term::FloatType => Ok(()),
384            Term::ListType(item_type) => item_type.validate(var_decls),
385            Term::TupleType(item_types) => item_types.validate(var_decls),
386            Term::StaticType => Ok(()),
387            Term::ConstType(ty) => ty.validate(var_decls),
388        }
389    }
390
391    pub(crate) fn substitute(&self, t: &Substitution) -> Self {
392        match self {
393            Term::Runtime(ty) => {
394                // RowVariables are represented as Term::Variable
395                ty.substitute1(t).into()
396            }
397            TypeArg::BoundedNat(_) | TypeArg::String(_) | TypeArg::Bytes(_) | TypeArg::Float(_) => {
398                self.clone()
399            } // We do not allow variables as bounds on BoundedNat's
400            TypeArg::List(elems) => {
401                // NOTE: This implements a hack allowing substitutions to
402                // replace `TypeArg::Variable`s representing "row variables"
403                // with a list that is to be spliced into the containing list.
404                // We won't need this code anymore once we stop conflating types
405                // with lists of types.
406
407                fn is_type(type_arg: &TypeArg) -> bool {
408                    match type_arg {
409                        TypeArg::Runtime(_) => true,
410                        TypeArg::Variable(v) => v.bound_if_row_var().is_some(),
411                        _ => false,
412                    }
413                }
414
415                let are_types = elems.first().map(is_type).unwrap_or(false);
416
417                Self::new_list_from_parts(elems.iter().map(|elem| match elem.substitute(t) {
418                    list @ TypeArg::List { .. } if are_types => SeqPart::Splice(list),
419                    list @ TypeArg::ListConcat { .. } if are_types => SeqPart::Splice(list),
420                    elem => SeqPart::Item(elem),
421                }))
422            }
423            TypeArg::ListConcat(lists) => {
424                // When a substitution instantiates spliced list variables, we
425                // may be able to merge the concatenated lists.
426                Self::new_list_from_parts(
427                    lists.iter().map(|list| SeqPart::Splice(list.substitute(t))),
428                )
429            }
430            Term::Tuple(elems) => {
431                Term::Tuple(elems.iter().map(|elem| elem.substitute(t)).collect())
432            }
433            TypeArg::TupleConcat(tuples) => {
434                // When a substitution instantiates spliced tuple variables,
435                // we may be able to merge the concatenated tuples.
436                Self::new_tuple_from_parts(
437                    tuples
438                        .iter()
439                        .map(|tuple| SeqPart::Splice(tuple.substitute(t))),
440                )
441            }
442            TypeArg::Variable(TermVar { idx, cached_decl }) => t.apply_var(*idx, cached_decl),
443            Term::RuntimeType(_) => self.clone(),
444            Term::BoundedNatType(_) => self.clone(),
445            Term::StringType => self.clone(),
446            Term::BytesType => self.clone(),
447            Term::FloatType => self.clone(),
448            Term::ListType(item_type) => Term::new_list_type(item_type.substitute(t)),
449            Term::TupleType(item_types) => Term::new_list_type(item_types.substitute(t)),
450            Term::StaticType => self.clone(),
451            Term::ConstType(ty) => Term::new_const(ty.substitute1(t)),
452        }
453    }
454
455    /// Helper method for [`TypeArg::new_list_from_parts`] and [`TypeArg::new_tuple_from_parts`].
456    fn new_seq_from_parts(
457        parts: impl IntoIterator<Item = SeqPart<Self>>,
458        make_items: impl Fn(Vec<Self>) -> Self,
459        make_concat: impl Fn(Vec<Self>) -> Self,
460    ) -> Self {
461        let mut items = Vec::new();
462        let mut seqs = Vec::new();
463
464        for part in parts {
465            match part {
466                SeqPart::Item(item) => items.push(item),
467                SeqPart::Splice(seq) => {
468                    if !items.is_empty() {
469                        seqs.push(make_items(std::mem::take(&mut items)));
470                    }
471                    seqs.push(seq);
472                }
473            }
474        }
475
476        if seqs.is_empty() {
477            make_items(items)
478        } else if items.is_empty() {
479            make_concat(seqs)
480        } else {
481            seqs.push(make_items(items));
482            make_concat(seqs)
483        }
484    }
485
486    /// Creates a new list from a sequence of [`SeqPart`]s.
487    pub fn new_list_from_parts(parts: impl IntoIterator<Item = SeqPart<Self>>) -> Self {
488        Self::new_seq_from_parts(
489            parts.into_iter().flat_map(ListPartIter::new),
490            TypeArg::List,
491            TypeArg::ListConcat,
492        )
493    }
494
495    /// Iterates over the [`SeqPart`]s of a list.
496    ///
497    /// # Examples
498    ///
499    /// The parts of a closed list are the items of that list wrapped in [`SeqPart::Item`]:
500    ///
501    /// ```
502    /// # use hugr_core::types::type_param::{Term, SeqPart};
503    /// # let a = Term::new_string("a");
504    /// # let b = Term::new_string("b");
505    /// let term = Term::new_list([a.clone(), b.clone()]);
506    ///
507    /// assert_eq!(
508    ///     term.into_list_parts().collect::<Vec<_>>(),
509    ///     vec![SeqPart::Item(a), SeqPart::Item(b)]
510    /// );
511    /// ```
512    ///
513    /// Parts of a concatenated list that are not closed lists are wrapped in [`SeqPart::Splice`]:
514    ///
515    /// ```
516    /// # use hugr_core::types::type_param::{Term, SeqPart};
517    /// # let a = Term::new_string("a");
518    /// # let b = Term::new_string("b");
519    /// # let c = Term::new_string("c");
520    /// let var = Term::new_var_use(0, Term::new_list_type(Term::StringType));
521    /// let term = Term::new_list_concat([
522    ///     Term::new_list([a.clone(), b.clone()]),
523    ///     var.clone(),
524    ///     Term::new_list([c.clone()])
525    ///  ]);
526    ///
527    /// assert_eq!(
528    ///     term.into_list_parts().collect::<Vec<_>>(),
529    ///     vec![SeqPart::Item(a), SeqPart::Item(b), SeqPart::Splice(var), SeqPart::Item(c)]
530    /// );
531    /// ```
532    ///
533    /// Nested concatenations are traversed recursively:
534    ///
535    /// ```
536    /// # use hugr_core::types::type_param::{Term, SeqPart};
537    /// # let a = Term::new_string("a");
538    /// # let b = Term::new_string("b");
539    /// # let c = Term::new_string("c");
540    /// let term = Term::new_list_concat([
541    ///     Term::new_list_concat([
542    ///         Term::new_list([a.clone()]),
543    ///         Term::new_list([b.clone()])
544    ///     ]),
545    ///     Term::new_list([]),
546    ///     Term::new_list([c.clone()])
547    /// ]);
548    ///
549    /// assert_eq!(
550    ///     term.into_list_parts().collect::<Vec<_>>(),
551    ///     vec![SeqPart::Item(a), SeqPart::Item(b), SeqPart::Item(c)]
552    /// );
553    /// ```
554    ///
555    /// When invoked on a type argument that is not a list, a single
556    /// [`SeqPart::Splice`] is returned that wraps the type argument.
557    /// This is the expected behaviour for type variables that stand for lists.
558    /// This behaviour also allows this method not to fail on ill-typed type arguments.
559    /// ```
560    /// # use hugr_core::types::type_param::{Term, SeqPart};
561    /// let term = Term::new_string("not a list");
562    /// assert_eq!(
563    ///     term.clone().into_list_parts().collect::<Vec<_>>(),
564    ///     vec![SeqPart::Splice(term)]
565    /// );
566    /// ```
567    #[inline]
568    pub fn into_list_parts(self) -> ListPartIter {
569        ListPartIter::new(SeqPart::Splice(self))
570    }
571
572    /// Creates a new tuple from a sequence of [`SeqPart`]s.
573    ///
574    /// Analogous to [`TypeArg::new_list_from_parts`].
575    pub fn new_tuple_from_parts(parts: impl IntoIterator<Item = SeqPart<Self>>) -> Self {
576        Self::new_seq_from_parts(
577            parts.into_iter().flat_map(TuplePartIter::new),
578            TypeArg::Tuple,
579            TypeArg::TupleConcat,
580        )
581    }
582
583    /// Iterates over the [`SeqPart`]s of a tuple.
584    ///
585    /// Analogous to [`TypeArg::into_list_parts`].
586    #[inline]
587    pub fn into_tuple_parts(self) -> TuplePartIter {
588        TuplePartIter::new(SeqPart::Splice(self))
589    }
590}
591
592impl Transformable for Term {
593    fn transform<T: TypeTransformer>(&mut self, tr: &T) -> Result<bool, T::Err> {
594        match self {
595            Term::Runtime(ty) => ty.transform(tr),
596            Term::List(elems) => elems.transform(tr),
597            Term::Tuple(elems) => elems.transform(tr),
598            Term::BoundedNat(_)
599            | Term::String(_)
600            | Term::Variable(_)
601            | Term::Float(_)
602            | Term::Bytes(_) => Ok(false),
603            Term::RuntimeType { .. } => Ok(false),
604            Term::BoundedNatType { .. } => Ok(false),
605            Term::StringType => Ok(false),
606            Term::BytesType => Ok(false),
607            Term::FloatType => Ok(false),
608            Term::ListType(item_type) => item_type.transform(tr),
609            Term::TupleType(item_types) => item_types.transform(tr),
610            Term::StaticType => Ok(false),
611            TypeArg::ListConcat(lists) => lists.transform(tr),
612            TypeArg::TupleConcat(tuples) => tuples.transform(tr),
613            Term::ConstType(ty) => ty.transform(tr),
614        }
615    }
616}
617
618impl TermVar {
619    /// Return the index.
620    #[must_use]
621    pub fn index(&self) -> usize {
622        self.idx
623    }
624
625    /// Determines whether this represents a row variable; if so, returns
626    /// the [`TypeBound`] of the individual types it might stand for.
627    #[must_use]
628    pub fn bound_if_row_var(&self) -> Option<TypeBound> {
629        if let Term::ListType(item_type) = &*self.cached_decl {
630            if let Term::RuntimeType(b) = **item_type {
631                return Some(b);
632            }
633        }
634        None
635    }
636}
637
638/// Checks that a [`Term`] is valid for a given type.
639pub fn check_term_type(term: &Term, type_: &Term) -> Result<(), TermTypeError> {
640    match (term, type_) {
641        (Term::Variable(TermVar { cached_decl, .. }), _) if type_.is_supertype(cached_decl) => {
642            Ok(())
643        }
644        (Term::Runtime(ty), Term::RuntimeType(bound)) if bound.contains(ty.least_upper_bound()) => {
645            Ok(())
646        }
647        (Term::List(elems), Term::ListType(item_type)) => {
648            elems.iter().try_for_each(|term| {
649                // Also allow elements that are RowVars if fitting into a List of Types
650                if let (Term::Variable(v), Term::RuntimeType(param_bound)) = (term, &**item_type) {
651                    if v.bound_if_row_var()
652                        .is_some_and(|arg_bound| param_bound.contains(arg_bound))
653                    {
654                        return Ok(());
655                    }
656                }
657                check_term_type(term, item_type)
658            })
659        }
660        (Term::ListConcat(lists), Term::ListType(item_type)) => lists
661            .iter()
662            .try_for_each(|list| check_term_type(list, item_type)),
663        (TypeArg::Tuple(_) | TypeArg::TupleConcat(_), TypeParam::TupleType(item_types)) => {
664            let term_parts: Vec<_> = term.clone().into_tuple_parts().collect();
665            let type_parts: Vec<_> = item_types.clone().into_list_parts().collect();
666
667            for (term, type_) in term_parts.iter().zip(&type_parts) {
668                match (term, type_) {
669                    (SeqPart::Item(term), SeqPart::Item(type_)) => {
670                        check_term_type(term, type_)?;
671                    }
672                    (_, SeqPart::Splice(_)) | (SeqPart::Splice(_), _) => {
673                        // TODO: Checking tuples with splicing requires more
674                        // sophisticated validation infrastructure to do well.
675                        warn!(
676                            "Validation for open tuples is not implemented yet, succeeding regardless..."
677                        );
678                        return Ok(());
679                    }
680                }
681            }
682
683            if term_parts.len() != type_parts.len() {
684                return Err(TermTypeError::WrongNumberTuple(
685                    term_parts.len(),
686                    type_parts.len(),
687                ));
688            }
689
690            Ok(())
691        }
692        (Term::BoundedNat(val), Term::BoundedNatType(bound)) if bound.valid_value(*val) => Ok(()),
693        (Term::String { .. }, Term::StringType) => Ok(()),
694        (Term::Bytes(_), Term::BytesType) => Ok(()),
695        (Term::Float(_), Term::FloatType) => Ok(()),
696
697        // Static types
698        (Term::StaticType, Term::StaticType) => Ok(()),
699        (Term::StringType, Term::StaticType) => Ok(()),
700        (Term::BytesType, Term::StaticType) => Ok(()),
701        (Term::BoundedNatType { .. }, Term::StaticType) => Ok(()),
702        (Term::FloatType, Term::StaticType) => Ok(()),
703        (Term::ListType { .. }, Term::StaticType) => Ok(()),
704        (Term::TupleType(_), Term::StaticType) => Ok(()),
705        (Term::RuntimeType(_), Term::StaticType) => Ok(()),
706        (Term::ConstType(_), Term::StaticType) => Ok(()),
707
708        _ => Err(TermTypeError::TypeMismatch {
709            term: Box::new(term.clone()),
710            type_: Box::new(type_.clone()),
711        }),
712    }
713}
714
715/// Check a list of [`Term`]s is valid for a list of types.
716pub fn check_term_types(terms: &[Term], types: &[Term]) -> Result<(), TermTypeError> {
717    if terms.len() != types.len() {
718        return Err(TermTypeError::WrongNumberArgs(terms.len(), types.len()));
719    }
720    for (term, type_) in terms.iter().zip(types.iter()) {
721        check_term_type(term, type_)?;
722    }
723    Ok(())
724}
725
/// Errors that can occur when checking that a [`Term`] has an expected type.
#[derive(Clone, Debug, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum TermTypeError {
    // The fields are self-describing: `term` is the offending term and `type_`
    // the type it was checked against; hence the `missing_docs` allowance.
    #[allow(missing_docs)]
    /// For now, general case of a term not fitting a type.
    /// We'll have more cases when we allow general Containers.
    // TODO It may become possible to combine this with ConstTypeError.
    #[error("Term {term} does not fit declared type {type_}")]
    TypeMismatch { term: Box<Term>, type_: Box<Term> },
    /// Wrong number of type arguments (actual vs expected).
    // For now this only happens at the top level (TypeArgs of op/type vs TypeParams of Op/TypeDef).
    // However in the future it may be applicable to e.g. contents of Tuples too.
    #[error("Wrong number of type arguments: {0} vs expected {1} declared type parameters")]
    WrongNumberArgs(usize, usize),

    /// Wrong number of type arguments in tuple (actual vs expected).
    ///
    /// The counts are numbers of sequence parts of the tuple and its tuple type.
    #[error(
        "Wrong number of type arguments to tuple parameter: {0} vs expected {1} declared type parameters"
    )]
    WrongNumberTuple(usize, usize),
    /// Opaque value type check error, converted from a
    /// [`crate::types::CustomCheckFailure`].
    #[error("Opaque type argument does not fit declared parameter type: {0}")]
    OpaqueTypeMismatch(#[from] crate::types::CustomCheckFailure),
    /// Invalid value; carries the offending type argument.
    #[error("Invalid value of type argument")]
    InvalidValue(Box<TypeArg>),
}
754
/// Part of a sequence.
///
/// Used to describe lists and tuples piecewise, e.g. by
/// [`Term::new_list_from_parts`] and [`Term::into_list_parts`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SeqPart<T> {
    /// An individual item in the sequence.
    Item(T),
    /// A subsequence that is spliced into the parent sequence.
    Splice(T),
}
763
764/// Iterator created by [`TypeArg::into_list_parts`].
765#[derive(Debug, Clone)]
766pub struct ListPartIter {
767    parts: SmallVec<[SeqPart<TypeArg>; 1]>,
768}
769
770impl ListPartIter {
771    #[inline]
772    fn new(part: SeqPart<TypeArg>) -> Self {
773        Self {
774            parts: smallvec![part],
775        }
776    }
777}
778
779impl Iterator for ListPartIter {
780    type Item = SeqPart<TypeArg>;
781
782    fn next(&mut self) -> Option<Self::Item> {
783        loop {
784            match self.parts.pop()? {
785                SeqPart::Splice(TypeArg::List(elems)) => self
786                    .parts
787                    .extend(elems.into_iter().rev().map(SeqPart::Item)),
788                SeqPart::Splice(TypeArg::ListConcat(lists)) => self
789                    .parts
790                    .extend(lists.into_iter().rev().map(SeqPart::Splice)),
791                part => return Some(part),
792            }
793        }
794    }
795}
796
797impl FusedIterator for ListPartIter {}
798
799/// Iterator created by [`TypeArg::into_tuple_parts`].
800#[derive(Debug, Clone)]
801pub struct TuplePartIter {
802    parts: SmallVec<[SeqPart<TypeArg>; 1]>,
803}
804
805impl TuplePartIter {
806    #[inline]
807    fn new(part: SeqPart<TypeArg>) -> Self {
808        Self {
809            parts: smallvec![part],
810        }
811    }
812}
813
814impl Iterator for TuplePartIter {
815    type Item = SeqPart<TypeArg>;
816
817    fn next(&mut self) -> Option<Self::Item> {
818        loop {
819            match self.parts.pop()? {
820                SeqPart::Splice(TypeArg::Tuple(elems)) => self
821                    .parts
822                    .extend(elems.into_iter().rev().map(SeqPart::Item)),
823                SeqPart::Splice(TypeArg::TupleConcat(tuples)) => self
824                    .parts
825                    .extend(tuples.into_iter().rev().map(SeqPart::Splice)),
826                part => return Some(part),
827            }
828        }
829    }
830}
831
832impl FusedIterator for TuplePartIter {}
833
#[cfg(test)]
mod test {
    use itertools::Itertools;

    use super::{Substitution, TypeArg, TypeParam, check_term_type};
    use crate::extension::prelude::{bool_t, usize_t};
    use crate::types::Term;
    use crate::types::type_param::SeqPart;
    use crate::types::{TypeBound, TypeRV, type_param::TermTypeError};

    /// Building a list/tuple from only `Item` parts is the same as building it
    /// directly from the items.
    #[test]
    fn new_list_from_parts_items() {
        let a = TypeArg::new_string("a");
        let b = TypeArg::new_string("b");

        let parts = [SeqPart::Item(a.clone()), SeqPart::Item(b.clone())];
        let items = [a, b];

        assert_eq!(
            TypeArg::new_list_from_parts(parts.clone()),
            TypeArg::new_list(items.clone())
        );

        assert_eq!(
            TypeArg::new_tuple_from_parts(parts),
            TypeArg::new_tuple(items)
        );
    }

    /// Spliced literal lists (including those nested in a concat) are flattened,
    /// while a spliced variable remains a separate concat operand.
    #[test]
    fn new_list_from_parts_flatten() {
        let a = Term::new_string("a");
        let b = Term::new_string("b");
        let c = Term::new_string("c");
        let d = Term::new_string("d");
        let var = Term::new_var_use(0, Term::new_list_type(Term::StringType));
        let parts = [
            SeqPart::Splice(Term::new_list([a.clone(), b.clone()])),
            SeqPart::Splice(Term::new_list_concat([Term::new_list([c.clone()])])),
            SeqPart::Item(d.clone()),
            SeqPart::Splice(var.clone()),
        ];
        assert_eq!(
            Term::new_list_from_parts(parts),
            Term::new_list_concat([Term::new_list([a, b, c, d]), var])
        );
    }

    /// Tuple counterpart of `new_list_from_parts_flatten`.
    #[test]
    fn new_tuple_from_parts_flatten() {
        let a = Term::new_string("a");
        let b = Term::new_string("b");
        let c = Term::new_string("c");
        let d = Term::new_string("d");
        let var = Term::new_var_use(0, Term::new_tuple([Term::StringType]));
        let parts = [
            SeqPart::Splice(Term::new_tuple([a.clone(), b.clone()])),
            SeqPart::Splice(Term::new_tuple_concat([Term::new_tuple([c.clone()])])),
            SeqPart::Item(d.clone()),
            SeqPart::Splice(var.clone()),
        ];
        assert_eq!(
            Term::new_tuple_from_parts(parts),
            Term::new_tuple_concat([Term::new_tuple([a, b, c, d]), var])
        );
    }

    #[test]
    fn type_arg_fits_param() {
        let rowvar = TypeRV::new_row_var_use;
        fn check(arg: impl Into<TypeArg>, param: &TypeParam) -> Result<(), TermTypeError> {
            check_term_type(&arg.into(), param)
        }
        fn check_seq<T: Clone + Into<TypeArg>>(
            args: &[T],
            param: &TypeParam,
        ) -> Result<(), TermTypeError> {
            let arg = args.iter().cloned().map_into().collect_vec().into();
            check_term_type(&arg, param)
        }
        // Simple cases: a runtime type fits a `Term::RuntimeType` parameter, but a
        // singleton sequence is a list, not a type.
        check(usize_t(), &TypeBound::Copyable.into()).unwrap();
        let seq_param = TypeParam::new_list_type(TypeBound::Copyable);
        check(usize_t(), &seq_param).unwrap_err();
        check_seq(&[usize_t()], &TypeBound::Linear.into()).unwrap_err();

        // Into a list of type, we can fit a single row var
        check(rowvar(0, TypeBound::Copyable), &seq_param).unwrap();
        // or a list of (types or row vars)
        check(vec![], &seq_param).unwrap();
        check_seq(&[rowvar(0, TypeBound::Copyable)], &seq_param).unwrap();
        check_seq(
            &[
                rowvar(1, TypeBound::Linear),
                usize_t().into(),
                rowvar(0, TypeBound::Copyable),
            ],
            &TypeParam::new_list_type(TypeBound::Linear),
        )
        .unwrap();
        // Next one fails because a list of Copyable types is required,
        // but row variable 1 is only Linear
        check_seq(
            &[
                rowvar(1, TypeBound::Linear),
                usize_t().into(),
                rowvar(0, TypeBound::Copyable),
            ],
            &seq_param,
        )
        .unwrap_err();
        // seq of seq of types is not allowed
        check(
            vec![usize_t().into(), vec![usize_t().into()].into()],
            &seq_param,
        )
        .unwrap_err();

        // Similar for nats (but no equivalent of fancy row vars)
        check(5, &TypeParam::max_nat_type()).unwrap();
        check_seq(&[5], &TypeParam::max_nat_type()).unwrap_err();
        let list_of_nat = TypeParam::new_list_type(TypeParam::max_nat_type());
        check(5, &list_of_nat).unwrap_err();
        check_seq(&[5], &list_of_nat).unwrap();
        check(TypeArg::new_var_use(0, list_of_nat.clone()), &list_of_nat).unwrap();
        // But no equivalent of row vars - can't append a nat onto a list-in-a-var:
        check(
            vec![5.into(), TypeArg::new_var_use(0, list_of_nat.clone())],
            &list_of_nat,
        )
        .unwrap_err();

        // `Term::TupleType` requires a `Term::Tuple` of the same number of elems
        let usize_and_ty =
            TypeParam::new_tuple_type([TypeParam::max_nat_type(), TypeBound::Copyable.into()]);
        check(
            TypeArg::Tuple(vec![5.into(), usize_t().into()]),
            &usize_and_ty,
        )
        .unwrap();
        check(
            TypeArg::Tuple(vec![usize_t().into(), 5.into()]),
            &usize_and_ty,
        )
        .unwrap_err(); // Wrong way around
        let two_types = TypeParam::new_tuple_type(Term::new_list([
            TypeBound::Linear.into(),
            TypeBound::Linear.into(),
        ]));
        check(TypeArg::new_var_use(0, two_types.clone()), &two_types).unwrap();
        // not a Row Var which could have any number of elems
        check(TypeArg::new_var_use(0, seq_param), &two_types).unwrap_err();
    }

    #[test]
    fn type_arg_subst_row() {
        let row_param = Term::new_list_type(TypeBound::Copyable);
        let row_arg: Term = vec![bool_t().into(), Term::UNIT].into();
        check_term_type(&row_arg, &row_param).unwrap();

        // Now say a row variable referring to *that* row was used
        // to instantiate an outer "row parameter" (list of type).
        let outer_param = Term::new_list_type(TypeBound::Linear);
        let outer_arg = Term::new_list([
            TypeRV::new_row_var_use(0, TypeBound::Copyable).into(),
            usize_t().into(),
        ]);
        check_term_type(&outer_arg, &outer_param).unwrap();

        let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg]));
        assert_eq!(
            outer_arg2,
            vec![bool_t().into(), Term::UNIT, usize_t().into()].into()
        );

        // Of course this is still valid (as substitution is guaranteed to preserve validity)
        check_term_type(&outer_arg2, &outer_param).unwrap();
    }

    #[test]
    fn subst_list_list() {
        let outer_param = Term::new_list_type(Term::new_list_type(TypeBound::Linear));
        let row_var_decl = Term::new_list_type(TypeBound::Copyable);
        let row_var_use = Term::new_var_use(0, row_var_decl.clone());
        let good_arg = Term::new_list([
            // The row variables here refer to `row_var_decl` above
            vec![usize_t().into()].into(),
            row_var_use.clone(),
            vec![row_var_use, usize_t().into()].into(),
        ]);
        check_term_type(&good_arg, &outer_param).unwrap();

        // Outer list cannot include single types:
        let Term::List(mut elems) = good_arg.clone() else {
            panic!()
        };
        elems.push(usize_t().into());
        assert_eq!(
            check_term_type(&Term::new_list(elems), &outer_param),
            Err(TermTypeError::TypeMismatch {
                term: Box::new(usize_t().into()),
                // The error reports the type expected for each element of the list:
                type_: Box::new(TypeParam::new_list_type(TypeBound::Linear))
            })
        );

        // Now substitute a list of two types for that row-variable
        let row_var_arg = vec![usize_t().into(), bool_t().into()].into();
        check_term_type(&row_var_arg, &row_var_decl).unwrap();
        let subst_arg = good_arg.substitute(&Substitution(std::slice::from_ref(&row_var_arg)));
        check_term_type(&subst_arg, &outer_param).unwrap(); // invariance of substitution
        assert_eq!(
            subst_arg,
            Term::new_list([
                Term::new_list([usize_t().into()]),
                row_var_arg,
                Term::new_list([usize_t().into(), bool_t().into(), usize_t().into()])
            ])
        );
    }

    #[test]
    fn bytes_json_roundtrip() {
        let bytes_arg = Term::Bytes(vec![0, 1, 2, 3, 255, 254, 253, 252].into());
        let serialized = serde_json::to_string(&bytes_arg).unwrap();
        let deserialized: Term = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, bytes_arg);
    }

    mod proptest {
        use proptest::prelude::*;

        use super::super::{TermVar, UpperBound};
        use crate::proptest::RecursionDepth;
        use crate::types::{Term, Type, TypeBound, proptest_utils::any_serde_type_param};

        impl Arbitrary for TermVar {
            type Parameters = RecursionDepth;
            type Strategy = BoxedStrategy<Self>;
            fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy {
                (any::<usize>(), any_serde_type_param(depth))
                    .prop_map(|(idx, cached_decl)| Self {
                        idx,
                        cached_decl: Box::new(cached_decl),
                    })
                    .boxed()
            }
        }

        impl Arbitrary for Term {
            type Parameters = RecursionDepth;
            type Strategy = BoxedStrategy<Self>;
            fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy {
                use prop::collection::vec;
                use prop::strategy::Union;
                // Leaf cases; each arm is listed once so that all are weighted equally.
                let mut strat = Union::new([
                    Just(Self::StringType).boxed(),
                    Just(Self::BytesType).boxed(),
                    Just(Self::FloatType).boxed(),
                    any::<TypeBound>().prop_map(Self::from).boxed(),
                    any::<UpperBound>().prop_map(Self::from).boxed(),
                    any::<u64>().prop_map(Self::from).boxed(),
                    any::<String>().prop_map(Self::from).boxed(),
                    any::<Vec<u8>>()
                        .prop_map(|bytes| Self::Bytes(bytes.into()))
                        .boxed(),
                    any::<f64>()
                        .prop_map(|value| Self::Float(value.into()))
                        .boxed(),
                    any_with::<Type>(depth).prop_map(Self::from).boxed(),
                ]);
                if !depth.leaf() {
                    // We descend here because these constructors contain Terms.
                    strat = strat
                        .or(
                            // TODO this is a bit dodgy, TypeArgVariables are supposed
                            // to be constructed from TypeArg::new_var_use. We are only
                            // using this instance for serialization now, but if we want
                            // to generate valid TypeArgs this will need to change.
                            any_with::<TermVar>(depth.descend())
                                .prop_map(Self::Variable)
                                .boxed(),
                        )
                        .or(any_with::<Self>(depth.descend())
                            .prop_map(Self::new_list_type)
                            .boxed())
                        .or(any_with::<Self>(depth.descend())
                            .prop_map(Self::new_tuple_type)
                            .boxed())
                        .or(vec(any_with::<Self>(depth.descend()), 0..3)
                            .prop_map(Self::new_list)
                            .boxed());
                }

                strat.boxed()
            }
        }

        proptest! {
            #[test]
            fn term_contains_itself(term: Term) {
                // `prop_assert!` reports the failing case through proptest (enabling
                // shrinking) rather than panicking.
                prop_assert!(term.is_supertype(&term));
            }
        }
    }
}