Skip to main content

mago_codex/ttype/template/
variance.rs

1use serde::Deserialize;
2use serde::Serialize;
3
4#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
5pub enum Variance {
6    Invariant,
7    Covariant,
8    Contravariant,
9}
10
11impl Variance {
12    #[inline]
13    #[must_use]
14    pub const fn is_invariant(&self) -> bool {
15        matches!(self, Variance::Invariant)
16    }
17
18    #[inline]
19    #[must_use]
20    pub const fn is_covariant(&self) -> bool {
21        matches!(self, Variance::Covariant)
22    }
23
24    #[inline]
25    #[must_use]
26    pub const fn is_contravariant(&self) -> bool {
27        matches!(self, Variance::Contravariant)
28    }
29
30    #[inline]
31    #[must_use]
32    pub const fn is_readonly(&self) -> bool {
33        matches!(self, Variance::Covariant | Variance::Invariant)
34    }
35
36    /// Combines an outer variance context with an inner variance context.
37    ///
38    /// This is used when resolving nested templates, e.g., `Outer<Inner<T>>`.
39    /// The variance of `T` relative to the outermost context depends on both
40    /// the variance of `T` within `Inner` and the variance of `Inner` within `Outer`.
41    ///
42    /// Rules:
43    ///
44    /// - Anything combined with Invariant results in Invariant.
45    /// - Covariant + Covariant = Covariant
46    /// - Contravariant + Contravariant = Covariant
47    /// - Covariant + Contravariant = Contravariant
48    /// - Contravariant + Covariant = Contravariant
49    #[inline]
50    #[must_use]
51    pub const fn combine(outer_variance: Self, inner_variance: Self) -> Self {
52        match (outer_variance, inner_variance) {
53            // If either is invariant, the result is invariant
54            (Variance::Invariant, _) | (_, Variance::Invariant) => Variance::Invariant,
55            // Co + Co = Co
56            (Variance::Covariant, Variance::Covariant) => Variance::Covariant,
57            // Contra + Contra = Co (double negative flips back)
58            (Variance::Contravariant, Variance::Contravariant) => Variance::Covariant,
59            // Co + Contra = Contra
60            (Variance::Covariant, Variance::Contravariant) => Variance::Contravariant,
61            // Contra + Co = Contra
62            (Variance::Contravariant, Variance::Covariant) => Variance::Contravariant,
63        }
64    }
65}
66
67impl std::fmt::Display for Variance {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        match self {
70            Variance::Invariant => write!(f, "invariant"),
71            Variance::Covariant => write!(f, "covariant"),
72            Variance::Contravariant => write!(f, "contravariant"),
73        }
74    }
75}