Skip to main content

rib/interpreter/
literal.rs

1use crate::wit_type::WitType;
2use crate::{IntoValueAndType, Value, ValueAndType};
3use std::cmp::Ordering;
4use std::fmt::Display;
5
/// Extracts a simple scalar [`LiteralValue`] from a value, when it has one.
pub trait GetLiteralValue {
    /// Returns the literal form of `self`, or `None` if it has no literal representation.
    fn get_literal(&self) -> Option<LiteralValue>;
}
9
10impl GetLiteralValue for ValueAndType {
11    fn get_literal(&self) -> Option<LiteralValue> {
12        match self {
13            ValueAndType {
14                value: Value::String(value),
15                ..
16            } => Some(LiteralValue::String(value.clone())),
17            ValueAndType {
18                value: Value::Char(code_point),
19                ..
20            } => char::from_u32(*code_point as u32)
21                .map(|c| c.to_string())
22                .map(LiteralValue::String),
23            ValueAndType {
24                value: Value::Bool(value),
25                ..
26            } => Some(LiteralValue::Bool(*value)),
27            ValueAndType {
28                value: Value::Enum(idx),
29                typ: WitType::Enum(typ),
30            } => {
31                // An enum can be turned into a simple literal and can be part of string concatenations
32                Some(LiteralValue::String(typ.cases[*idx as usize].clone()))
33            }
34            ValueAndType {
35                value:
36                    Value::Variant {
37                        case_idx,
38                        case_value,
39                    },
40                typ: WitType::Variant(typ),
41            } => {
42                // A no arg variant can be turned into a simple literal and can be part of string concatenations
43                if case_value.is_none() {
44                    Some(LiteralValue::String(
45                        typ.cases[*case_idx as usize].name.clone(),
46                    ))
47                } else {
48                    None
49                }
50            }
51            other => internal::get_numeric_value(other).map(LiteralValue::Num),
52        }
53    }
54}
55
/// A scalar literal value: a coerced number, a string, or a boolean.
///
/// Produced from runtime values by [`GetLiteralValue::get_literal`] and from raw
/// strings by the `From<String>` impl.
// NOTE: the derived `PartialOrd` orders values of *different* variants by declaration
// order (`Num` < `String` < `Bool`); reordering the variants changes comparison results.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub enum LiteralValue {
    /// A number, coerced to one of the 64-bit representations.
    Num(CoercedNumericValue),
    /// A string (also used for chars, enum cases and payload-less variant cases).
    String(String),
    /// A boolean.
    Bool(bool),
}
62
63impl LiteralValue {
64    pub fn get_bool(&self) -> Option<bool> {
65        match self {
66            LiteralValue::Bool(value) => Some(*value),
67            _ => None,
68        }
69    }
70
71    pub fn get_number(&self) -> Option<CoercedNumericValue> {
72        match self {
73            LiteralValue::Num(num) => Some(num.clone()),
74            _ => None,
75        }
76    }
77
78    pub fn as_string(&self) -> String {
79        match self {
80            LiteralValue::Num(number) => number.to_string(),
81            LiteralValue::String(value) => value.clone(),
82            LiteralValue::Bool(value) => value.to_string(),
83        }
84    }
85}
86
87impl From<String> for LiteralValue {
88    fn from(value: String) -> Self {
89        if let Ok(u64) = value.parse::<u64>() {
90            LiteralValue::Num(CoercedNumericValue::PosInt(u64))
91        } else if let Ok(i64_value) = value.parse::<i64>() {
92            LiteralValue::Num(CoercedNumericValue::NegInt(i64_value))
93        } else if let Ok(f64_value) = value.parse::<f64>() {
94            LiteralValue::Num(CoercedNumericValue::Float(f64_value))
95        } else if let Ok(bool) = value.parse::<bool>() {
96            LiteralValue::Bool(bool)
97        } else {
98            LiteralValue::String(value.to_string())
99        }
100    }
101}
102
/// A numeric WASM value coerced to one of three 64-bit representations
/// (see `internal::get_numeric_value` for the mapping from `Value`):
///
/// - `PosInt`: any unsigned integer (`u8`..`u64`), widened to `u64`.
/// - `NegInt`: any *signed* integer (`s8`..`s64`), widened to `i64`. Despite the name,
///   this also holds zero and positive values.
/// - `Float`: `f32` or `f64`, widened to `f64`.
///
/// `PartialEq` and `PartialOrd` are implemented by hand rather than derived so that
/// values of different variants compare numerically.
#[derive(Clone, Debug)]
pub enum CoercedNumericValue {
    PosInt(u64),
    NegInt(i64),
    Float(f64),
}
110
111impl CoercedNumericValue {
112    pub fn is_zero(&self) -> bool {
113        match self {
114            CoercedNumericValue::PosInt(val) => *val == 0,
115            CoercedNumericValue::NegInt(val) => *val == 0,
116            CoercedNumericValue::Float(val) => *val == 0.0,
117        }
118    }
119
120    pub fn cast_to(&self, analysed_type: &WitType) -> Option<ValueAndType> {
121        match self {
122            CoercedNumericValue::PosInt(number) => {
123                let num = *number;
124
125                match analysed_type {
126                    WitType::U8(_) if num <= u8::MAX as u64 => {
127                        Some((num as u8).into_value_and_type())
128                    }
129                    WitType::U16(_) if num <= u16::MAX as u64 => {
130                        Some((num as u16).into_value_and_type())
131                    }
132                    WitType::U32(_) if num <= u32::MAX as u64 => {
133                        Some((num as u32).into_value_and_type())
134                    }
135                    WitType::U64(_) => Some(num.into_value_and_type()),
136
137                    WitType::S8(_) if num <= i8::MAX as u64 => {
138                        Some((num as i8).into_value_and_type())
139                    }
140                    WitType::S16(_) if num <= i16::MAX as u64 => {
141                        Some((num as i16).into_value_and_type())
142                    }
143                    WitType::S32(_) if num <= i32::MAX as u64 => {
144                        Some((num as i32).into_value_and_type())
145                    }
146                    WitType::S64(_) if num <= i64::MAX as u64 => {
147                        Some((num as i64).into_value_and_type())
148                    }
149
150                    WitType::F32(_) if num <= f32::MAX as u64 => {
151                        Some((num as f32).into_value_and_type())
152                    }
153                    WitType::F64(_) if num <= f64::MAX as u64 => {
154                        Some((num as f64).into_value_and_type())
155                    }
156
157                    _ => None,
158                }
159            }
160
161            CoercedNumericValue::NegInt(number) => {
162                let num = *number;
163
164                match analysed_type {
165                    WitType::S8(_) if num >= i8::MIN as i64 && num <= i8::MAX as i64 => {
166                        Some((num as i8).into_value_and_type())
167                    }
168                    WitType::S16(_) if num >= i16::MIN as i64 && num <= i16::MAX as i64 => {
169                        Some((num as i16).into_value_and_type())
170                    }
171                    WitType::S32(_) if num >= i32::MIN as i64 && num <= i32::MAX as i64 => {
172                        Some((num as i32).into_value_and_type())
173                    }
174                    WitType::S64(_) => Some(num.into_value_and_type()),
175
176                    // Allow unsigned conversion only if non-negative
177                    WitType::U8(_) if num >= 0 && num <= u8::MAX as i64 => {
178                        Some((num as u8).into_value_and_type())
179                    }
180                    WitType::U16(_) if num >= 0 && num <= u16::MAX as i64 => {
181                        Some((num as u16).into_value_and_type())
182                    }
183                    WitType::U32(_) if num >= 0 && num <= u32::MAX as i64 => {
184                        Some((num as u32).into_value_and_type())
185                    }
186                    WitType::U64(_) if num >= 0 => Some((num as u64).into_value_and_type()),
187
188                    WitType::F32(_) if num >= f32::MIN as i64 && num <= f32::MAX as i64 => {
189                        Some((num as f32).into_value_and_type())
190                    }
191                    WitType::F64(_) if num >= f64::MIN as i64 && num <= f64::MAX as i64 => {
192                        Some((num as f64).into_value_and_type())
193                    }
194
195                    _ => None,
196                }
197            }
198
199            CoercedNumericValue::Float(number) => {
200                let num = *number;
201
202                match analysed_type {
203                    WitType::F64(_) => Some(num.into_value_and_type()),
204
205                    WitType::F32(_)
206                        if num.is_finite() && num >= f32::MIN as f64 && num <= f32::MAX as f64 =>
207                    {
208                        Some((num as f32).into_value_and_type())
209                    }
210
211                    WitType::U64(_)
212                        if num.is_finite()
213                            && num >= 0.0
214                            && num <= u64::MAX as f64
215                            && num.fract() == 0.0 =>
216                    {
217                        Some((num as u64).into_value_and_type())
218                    }
219
220                    WitType::S64(_)
221                        if num.is_finite()
222                            && num >= i64::MIN as f64
223                            && num <= i64::MAX as f64
224                            && num.fract() == 0.0 =>
225                    {
226                        Some((num as i64).into_value_and_type())
227                    }
228
229                    _ => None,
230                }
231            }
232        }
233    }
234}
235
236macro_rules! impl_ops {
237    ($trait:ident, $method:ident, $checked_method:ident) => {
238        impl std::ops::$trait for CoercedNumericValue {
239            type Output = Result<Self, String>;
240
241            fn $method(self, rhs: Self) -> Self::Output {
242                use CoercedNumericValue::*;
243                Ok(match (self, rhs) {
244                    (Float(a), Float(b)) => Float(a.$method(b)),
245                    (Float(a), PosInt(b)) => Float(a.$method(b as f64)),
246                    (Float(a), NegInt(b)) => Float(a.$method(b as f64)),
247                    (PosInt(a), Float(b)) => Float((a as f64).$method(b)),
248                    (NegInt(a), Float(b)) => Float((a as f64).$method(b)),
249                    (PosInt(a), PosInt(b)) => a.$checked_method(b).map(PosInt).ok_or(format!(
250                        "overflow in unsigned operation between {} and {}",
251                        a, b
252                    ))?,
253                    (NegInt(a), NegInt(b)) => a.$checked_method(b).map(NegInt).ok_or(format!(
254                        "overflow in signed operation between {} and {}",
255                        a, b
256                    ))?,
257                    (PosInt(a), NegInt(b)) => (a as i64).$checked_method(b).map(NegInt).ok_or(
258                        format!("overflow in signed operation between {} and {}", a, b),
259                    )?,
260                    (NegInt(a), PosInt(b)) => a.$checked_method(b as i64).map(NegInt).ok_or(
261                        format!("overflow in signed operation between {} and {}", a, b),
262                    )?,
263                })
264            }
265        }
266    };
267}
268
269impl_ops!(Add, add, checked_add);
270impl_ops!(Sub, sub, checked_sub);
271impl_ops!(Mul, mul, checked_mul);
272impl_ops!(Div, div, checked_div);
273
274// Auto-derived PartialOrd fails if types don't match
275// and therefore custom impl.
276impl PartialOrd for CoercedNumericValue {
277    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
278        use CoercedNumericValue::*;
279        match (self, other) {
280            (PosInt(a), PosInt(b)) => a.partial_cmp(b),
281            (NegInt(a), NegInt(b)) => a.partial_cmp(b),
282            (Float(a), Float(b)) => a.partial_cmp(b),
283
284            (PosInt(a), NegInt(b)) => {
285                if let Ok(b_as_u64) = u64::try_from(*b) {
286                    a.partial_cmp(&b_as_u64)
287                } else {
288                    Some(Ordering::Greater) // Positive numbers are greater than negative numbers
289                }
290            }
291
292            (NegInt(a), PosInt(b)) => {
293                if let Ok(a_as_u64) = u64::try_from(*a) {
294                    a_as_u64.partial_cmp(b)
295                } else {
296                    Some(Ordering::Less) // Negative numbers are less than positive numbers
297                }
298            }
299
300            (PosInt(a), Float(b)) => (*a as f64).partial_cmp(b),
301
302            (Float(a), PosInt(b)) => a.partial_cmp(&(*b as f64)),
303
304            (NegInt(a), Float(b)) => (*a as f64).partial_cmp(b),
305
306            (Float(a), NegInt(b)) => a.partial_cmp(&(*b as f64)),
307        }
308    }
309}
310
311// Similarly, auto-derived PartialEq fails if types don't match
312// and therefore custom impl
313// There is a high chance two variables can be inferred S32(1) and U32(1)
314impl PartialEq for CoercedNumericValue {
315    fn eq(&self, other: &Self) -> bool {
316        use CoercedNumericValue::*;
317        match (self, other) {
318            (PosInt(a), PosInt(b)) => a == b,
319            (NegInt(a), NegInt(b)) => a == b,
320            (Float(a), Float(b)) => a == b,
321
322            // Comparing PosInt with NegInt
323            (PosInt(a), NegInt(b)) => {
324                if let Ok(b_as_u64) = u64::try_from(*b) {
325                    a == &b_as_u64
326                } else {
327                    false
328                }
329            }
330
331            // Comparing NegInt with PosInt
332            (NegInt(a), PosInt(b)) => {
333                if let Ok(a_as_u64) = u64::try_from(*a) {
334                    &a_as_u64 == b
335                } else {
336                    false
337                }
338            }
339
340            // Comparing PosInt with Float
341            (PosInt(a), Float(b)) => (*a as f64) == *b,
342
343            // Comparing Float with PosInt
344            (Float(a), PosInt(b)) => *a == (*b as f64),
345
346            // Comparing NegInt with Float
347            (NegInt(a), Float(b)) => (*a as f64) == *b,
348
349            // Comparing Float with NegInt
350            (Float(a), NegInt(b)) => *a == (*b as f64),
351        }
352    }
353}
354
355impl Display for CoercedNumericValue {
356    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357        match self {
358            CoercedNumericValue::PosInt(value) => write!(f, "{value}"),
359            CoercedNumericValue::NegInt(value) => write!(f, "{value}"),
360            CoercedNumericValue::Float(value) => write!(f, "{value}"),
361        }
362    }
363}
364
365impl Display for LiteralValue {
366    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367        match self {
368            LiteralValue::Num(number) => write!(f, "{number}"),
369            LiteralValue::String(value) => write!(f, "{value}"),
370            LiteralValue::Bool(value) => write!(f, "{value}"),
371        }
372    }
373}
374
375mod internal {
376    use crate::interpreter::literal::CoercedNumericValue;
377    use crate::{Value, ValueAndType};
378
379    pub(crate) fn get_numeric_value(value_and_type: &ValueAndType) -> Option<CoercedNumericValue> {
380        match &value_and_type.value {
381            Value::S8(value) => Some(CoercedNumericValue::NegInt(*value as i64)),
382            Value::S16(value) => Some(CoercedNumericValue::NegInt(*value as i64)),
383            Value::S32(value) => Some(CoercedNumericValue::NegInt(*value as i64)),
384            Value::S64(value) => Some(CoercedNumericValue::NegInt(*value)),
385            Value::U8(value) => Some(CoercedNumericValue::PosInt(*value as u64)),
386            Value::U16(value) => Some(CoercedNumericValue::PosInt(*value as u64)),
387            Value::U32(value) => Some(CoercedNumericValue::PosInt(*value as u64)),
388            Value::U64(value) => Some(CoercedNumericValue::PosInt(*value)),
389            Value::F32(value) => Some(CoercedNumericValue::Float(*value as f64)),
390            Value::F64(value) => Some(CoercedNumericValue::Float(*value)),
391            _ => None,
392        }
393    }
394}