mago_codex/ttype/template/
variance.rs

1use serde::Deserialize;
2use serde::Serialize;
3
4#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
5pub enum Variance {
6    Invariant,
7    Covariant,
8    Contravariant,
9}
10
11impl Variance {
12    #[inline]
13    pub const fn is_invariant(&self) -> bool {
14        matches!(self, Variance::Invariant)
15    }
16
17    #[inline]
18    pub const fn is_covariant(&self) -> bool {
19        matches!(self, Variance::Covariant)
20    }
21
22    #[inline]
23    pub const fn is_contravariant(&self) -> bool {
24        matches!(self, Variance::Contravariant)
25    }
26
27    #[inline]
28    pub const fn is_readonly(&self) -> bool {
29        matches!(self, Variance::Covariant | Variance::Invariant)
30    }
31
32    /// Combines an outer variance context with an inner variance context.
33    ///
34    /// This is used when resolving nested templates, e.g., `Outer<Inner<T>>`.
35    /// The variance of `T` relative to the outermost context depends on both
36    /// the variance of `T` within `Inner` and the variance of `Inner` within `Outer`.
37    ///
38    /// Rules:
39    ///
40    /// - Anything combined with Invariant results in Invariant.
41    /// - Covariant + Covariant = Covariant
42    /// - Contravariant + Contravariant = Covariant
43    /// - Covariant + Contravariant = Contravariant
44    /// - Contravariant + Covariant = Contravariant
45    #[inline]
46    pub const fn combine(outer_variance: Self, inner_variance: Self) -> Self {
47        match (outer_variance, inner_variance) {
48            // If either is invariant, the result is invariant
49            (Variance::Invariant, _) | (_, Variance::Invariant) => Variance::Invariant,
50            // Co + Co = Co
51            (Variance::Covariant, Variance::Covariant) => Variance::Covariant,
52            // Contra + Contra = Co (double negative flips back)
53            (Variance::Contravariant, Variance::Contravariant) => Variance::Covariant,
54            // Co + Contra = Contra
55            (Variance::Covariant, Variance::Contravariant) => Variance::Contravariant,
56            // Contra + Co = Contra
57            (Variance::Contravariant, Variance::Covariant) => Variance::Contravariant,
58        }
59    }
60}
61
62impl std::fmt::Display for Variance {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        match self {
65            Variance::Invariant => write!(f, "invariant"),
66            Variance::Covariant => write!(f, "covariant"),
67            Variance::Contravariant => write!(f, "contravariant"),
68        }
69    }
70}