Skip to main content

luaur_common/records/
variant.rs

1//! Faithful port of Luau's `Variant<Ts...>` — a `std::variant`-like tagged union.
2//! Reference: `luau/Common/include/Luau/Variant.h`. Oracle:
3//! `luau/tests/Variant.test.cpp` + `/tmp/variant_proto.rs` (DefaultCtor, Create,
4//! Emplace, NonPOD copy, Equality, Visit — all pass).
5//!
6//! Rust has no variadic generics, so the C++ variadic `Variant<Ts...>` becomes a
7//! fixed-arity enum family `Variant1<T0> .. Variant7<..>` (the arities Luau
8//! actually instantiates; max is 7). A Rust `enum` *is* a tagged union, so this
9//! is safe (no type-erased storage, no fn-pointer dispatch tables) and idiomatic.
10//!
11//! Mechanical mapping for callers (e.g. the eventual Analysis port):
12//! - `Variant<A, B, C>`            -> `Variant3<A, B, C>`
13//! - `Variant<A,B> x = a;`         -> `Variant2::V0(a)` (no blanket `From<Ti>` —
14//!   Rust coherence forbids it since `T0` could equal `T1`; construct the variant
15//!   directly at the position the type occupies)
16//! - `v.get_if<B>()` (B is pos 1)  -> `v.get_if_1()` / `v.get_if_1_mut()`
17//! - `v.emplace<B>(args)`          -> `v = Variant2::V1(B::from(args))`
18//! - `v.index()`                   -> `v.index()`
19//! - `visit(overloaded{...}, v)`   -> `match v { Variant3::V0(x) => …, … }`
20//!
//! `==` (C++ `operator==`) comes from `#[derive(PartialEq)]`, and `Default`
//! (C++ `Variant()`, which constructs the first alternative) comes from a
//! generated impl that default-constructs `V0`; `valueless_by_exception()`
//! always returns `false`.
24
/// Generates one `VariantN` enum plus its `index`/`get_if_*`/`Default` API.
///
/// Input shape: the enum name with its type parameters, then the first
/// alternative `V0(getter, getter_mut)`, then one
/// `idx: Variant<TypeParam>(getter, getter_mut)` entry per further alternative.
/// `idx` must be the alternative's zero-based position: it is exactly what the
/// generated `index()` reports for that alternative.
macro_rules! define_variant {
    (
        $name:ident < $t0:ident $(, $t:ident)* >
        = $v0:ident($g0:ident, $g0m:ident)
        $(, $idx:literal : $v:ident < $ty:ident > ($g:ident, $gm:ident) )*
    ) => {
        /// Tagged union holding exactly one of its alternatives; fixed-arity
        /// port of Luau's `Variant<Ts...>`.
        // Payload types are chosen by callers and stored inline (as the C++
        // original does) rather than boxed here, so the size lint is silenced.
        #[allow(clippy::large_enum_variant)]
        #[derive(Clone, Debug, PartialEq, Eq, Hash)]
        pub enum $name<$t0 $(, $t)*> {
            /// Alternative at index 0; also the alternative `Default` constructs.
            $v0($t0),
            $(
                #[doc = concat!("Alternative at index ", stringify!($idx), ".")]
                $v($ty),
            )*
        }

        impl<$t0 $(, $t)*> $name<$t0 $(, $t)*> {
            /// `index()` / `typeId` — zero-based position of the active alternative.
            #[must_use]
            pub const fn index(&self) -> usize {
                match self {
                    Self::$v0(_) => 0,
                    $( Self::$v(_) => $idx, )*
                }
            }

            /// Always `false` (this port has no valueless state). Matches the C++
            /// `valueless_by_exception`.
            #[must_use]
            pub const fn valueless_by_exception(&self) -> bool {
                false
            }

            #[doc = concat!(
                "Shared reference to the `", stringify!($v0),
                "` payload (index 0), or `None` if another alternative is active."
            )]
            #[must_use]
            pub const fn $g0(&self) -> ::core::option::Option<&$t0> {
                match self {
                    Self::$v0(x) => ::core::option::Option::Some(x),
                    // Unreachable when the enum has a single alternative (arity 1).
                    #[allow(unreachable_patterns)]
                    _ => ::core::option::Option::None,
                }
            }

            #[doc = concat!(
                "Mutable reference to the `", stringify!($v0),
                "` payload (index 0), or `None` if another alternative is active."
            )]
            #[must_use]
            pub fn $g0m(&mut self) -> ::core::option::Option<&mut $t0> {
                match self {
                    Self::$v0(x) => ::core::option::Option::Some(x),
                    // Unreachable when the enum has a single alternative (arity 1).
                    #[allow(unreachable_patterns)]
                    _ => ::core::option::Option::None,
                }
            }
            $(
                #[doc = concat!(
                    "Shared reference to the `", stringify!($v), "` payload (index ",
                    stringify!($idx), "), or `None` if another alternative is active."
                )]
                #[must_use]
                pub const fn $g(&self) -> ::core::option::Option<&$ty> {
                    match self {
                        Self::$v(x) => ::core::option::Option::Some(x),
                        #[allow(unreachable_patterns)]
                        _ => ::core::option::Option::None,
                    }
                }

                #[doc = concat!(
                    "Mutable reference to the `", stringify!($v), "` payload (index ",
                    stringify!($idx), "), or `None` if another alternative is active."
                )]
                #[must_use]
                pub fn $gm(&mut self) -> ::core::option::Option<&mut $ty> {
                    match self {
                        Self::$v(x) => ::core::option::Option::Some(x),
                        #[allow(unreachable_patterns)]
                        _ => ::core::option::Option::None,
                    }
                }
            )*
        }

        // C++ `Variant()` default-constructs the first alternative.
        impl<$t0: ::core::default::Default $(, $t)*> ::core::default::Default for $name<$t0 $(, $t)*> {
            fn default() -> Self {
                Self::$v0(<$t0 as ::core::default::Default>::default())
            }
        }
    };
}
94
95define_variant!(Variant1<T0> = V0(get_if_0, get_if_0_mut));
96define_variant!(
97    Variant2<T0, T1> = V0(get_if_0, get_if_0_mut),
98    1: V1<T1>(get_if_1, get_if_1_mut)
99);
100define_variant!(
101    Variant3<T0, T1, T2> = V0(get_if_0, get_if_0_mut),
102    1: V1<T1>(get_if_1, get_if_1_mut),
103    2: V2<T2>(get_if_2, get_if_2_mut)
104);
105define_variant!(
106    Variant4<T0, T1, T2, T3> = V0(get_if_0, get_if_0_mut),
107    1: V1<T1>(get_if_1, get_if_1_mut),
108    2: V2<T2>(get_if_2, get_if_2_mut),
109    3: V3<T3>(get_if_3, get_if_3_mut)
110);
111define_variant!(
112    Variant5<T0, T1, T2, T3, T4> = V0(get_if_0, get_if_0_mut),
113    1: V1<T1>(get_if_1, get_if_1_mut),
114    2: V2<T2>(get_if_2, get_if_2_mut),
115    3: V3<T3>(get_if_3, get_if_3_mut),
116    4: V4<T4>(get_if_4, get_if_4_mut)
117);
118define_variant!(
119    Variant6<T0, T1, T2, T3, T4, T5> = V0(get_if_0, get_if_0_mut),
120    1: V1<T1>(get_if_1, get_if_1_mut),
121    2: V2<T2>(get_if_2, get_if_2_mut),
122    3: V3<T3>(get_if_3, get_if_3_mut),
123    4: V4<T4>(get_if_4, get_if_4_mut),
124    5: V5<T5>(get_if_5, get_if_5_mut)
125);
126define_variant!(
127    Variant7<T0, T1, T2, T3, T4, T5, T6> = V0(get_if_0, get_if_0_mut),
128    1: V1<T1>(get_if_1, get_if_1_mut),
129    2: V2<T2>(get_if_2, get_if_2_mut),
130    3: V3<T3>(get_if_3, get_if_3_mut),
131    4: V4<T4>(get_if_4, get_if_4_mut),
132    5: V5<T5>(get_if_5, get_if_5_mut),
133    6: V6<T6>(get_if_6, get_if_6_mut)
134);
135
#[cfg(test)]
mod tests {
    use super::{Variant2, Variant3};
    use alloc::string::{String, ToString};

    // Mirrors luau/tests/Variant.test.cpp (DefaultCtor / Create / Emplace /
    // NonPOD / Equality / Visit), plus coverage of the `_mut` accessors.
    #[test]
    fn variant_behavior() {
        // DefaultCtor: first alternative, default value.
        let v: Variant2<i32, String> = Variant2::default();
        assert_eq!(v.get_if_0(), Some(&0));
        assert!(v.get_if_1().is_none());
        assert_eq!(v.index(), 0);
        assert!(!v.valueless_by_exception());

        // Create + get_if by position; the inactive position yields `None`.
        let v1: Variant2<i32, String> = Variant2::V1("hi".to_string());
        assert_eq!(v1.get_if_1().map(String::as_str), Some("hi"));
        assert!(v1.get_if_0().is_none());
        assert_eq!(v1.index(), 1);

        // Emplace == reassign: both the active index and the payload switch.
        let mut m: Variant2<i32, String> = Variant2::V0(5);
        assert_eq!(m.get_if_0(), Some(&5));
        m = Variant2::V1("x".to_string());
        assert_eq!(m.index(), 1);
        assert!(m.get_if_0().is_none());
        assert_eq!(m.get_if_1().map(String::as_str), Some("x"));

        // Mutable access edits the active payload in place; the inactive
        // position yields `None`.
        m.get_if_1_mut().expect("V1 is active").push('y');
        assert!(m.get_if_0_mut().is_none());
        assert_eq!(m.get_if_1().map(String::as_str), Some("xy"));

        // NonPOD copy via Clone: equal, but an independent deep copy.
        let mut mc = m.clone();
        assert_eq!(m, mc);
        mc.get_if_1_mut().expect("clone keeps V1").push('z');
        assert_ne!(m, mc);
        assert_eq!(m.get_if_1().map(String::as_str), Some("xy"));

        // Equality: same variant+value; default == V0(0).
        let a: Variant2<i32, String> = Variant2::V0(0);
        assert_eq!(a, Variant2::<i32, String>::default());
        assert_ne!(v1, Variant2::V1("me".to_string()));
        assert_ne!(v1, Variant2::V0(1));

        // Visit -> match (arity 3).
        let t: Variant3<i32, bool, String> = Variant3::V2("z".to_string());
        let rendered = match &t {
            Variant3::V0(n) => n.to_string(),
            Variant3::V1(b) => b.to_string(),
            Variant3::V2(s) => s.clone(),
        };
        assert_eq!(rendered, "z");
    }
}