Skip to main content

runmat_runtime/builtins/common/
arg_tokens.rs

1use runmat_builtins::{LiteralValue, ResolveContext, Value};
2
/// A lightweight, normalised view of a builtin call argument.
///
/// Tokens are produced either from runtime [`Value`]s or from compile-time
/// literal arguments. Text is stored ASCII-lowercased by the producers in this
/// module, so option names can be compared without caring about case.
#[derive(Debug, Clone, PartialEq)]
pub enum ArgToken {
    /// A numeric scalar (integers are widened to `f64`).
    Number(f64),
    /// A logical scalar.
    Bool(bool),
    /// A single piece of text (string scalar or single-row char array).
    String(String),
    /// A one-dimensional collection of tokens (row/column vector or literal list).
    Vector(Vec<ArgToken>),
    /// Anything that could not be classified as one of the variants above.
    Unknown,
}
11
12pub fn tokens_from_values(args: &[Value]) -> Vec<ArgToken> {
13    args.iter().map(token_from_value).collect()
14}
15
16pub fn tokens_from_context(ctx: &ResolveContext) -> Vec<ArgToken> {
17    ctx.literal_args.iter().map(token_from_literal).collect()
18}
19
20fn token_from_literal(value: &LiteralValue) -> ArgToken {
21    match value {
22        LiteralValue::Number(num) => ArgToken::Number(*num),
23        LiteralValue::Bool(value) => ArgToken::Bool(*value),
24        LiteralValue::String(text) => ArgToken::String(text.to_ascii_lowercase()),
25        LiteralValue::Vector(values) => {
26            ArgToken::Vector(values.iter().map(token_from_literal).collect())
27        }
28        LiteralValue::Unknown => ArgToken::Unknown,
29    }
30}
31
32fn token_from_value(value: &Value) -> ArgToken {
33    match value {
34        Value::Num(num) => ArgToken::Number(*num),
35        Value::Int(value) => ArgToken::Number(value.to_f64()),
36        Value::Bool(value) => ArgToken::Bool(*value),
37        Value::String(text) => ArgToken::String(text.to_ascii_lowercase()),
38        Value::StringArray(arr) if arr.data.len() == 1 => {
39            ArgToken::String(arr.data[0].to_ascii_lowercase())
40        }
41        Value::CharArray(arr) if arr.rows == 1 => {
42            let text: String = arr.data.iter().collect();
43            ArgToken::String(text.to_ascii_lowercase())
44        }
45        Value::Tensor(tensor) => token_from_tensor(&tensor.data, &tensor.shape),
46        Value::LogicalArray(arr) => token_from_logical(&arr.data, &arr.shape),
47        _ => ArgToken::Unknown,
48    }
49}
50
51fn token_from_tensor(data: &[f64], shape: &[usize]) -> ArgToken {
52    if data.len() == 1 {
53        return ArgToken::Number(data[0]);
54    }
55    if is_vector_shape(shape) {
56        return ArgToken::Vector(data.iter().copied().map(ArgToken::Number).collect());
57    }
58    ArgToken::Unknown
59}
60
61fn token_from_logical(data: &[u8], shape: &[usize]) -> ArgToken {
62    if data.len() == 1 {
63        return ArgToken::Bool(data[0] != 0);
64    }
65    if is_vector_shape(shape) {
66        return ArgToken::Vector(data.iter().map(|b| ArgToken::Bool(*b != 0)).collect());
67    }
68    ArgToken::Unknown
69}
70
/// Returns `true` when `shape` describes a vector.
///
/// A vector is a one-dimensional shape, or a two-dimensional `1×N` / `N×1` shape
/// (`N` may be zero). Trailing singleton dimensions past the second are ignored
/// before classifying. A `1×3×1` array is the same array as a `1×3` one, just as
/// `size(ones(1,3,1))` is `[1 3]` in MATLAB. A shape with a non-singleton third
/// dimension (e.g. `1×1×5`) is still *not* a vector. An empty shape is never a
/// vector.
fn is_vector_shape(shape: &[usize]) -> bool {
    // Count the run of 1s at the end of the shape, looking only at dimensions
    // from index 2 onwards so the first two dimensions are always kept.
    let trailing_singletons = shape
        .iter()
        .skip(2)
        .rev()
        .take_while(|&&dim| dim == 1)
        .count();
    let significant = &shape[..shape.len() - trailing_singletons];

    match significant {
        [] => false,
        [_] => true,
        [rows, cols] => *rows == 1 || *cols == 1,
        _ => false,
    }
}
83
#[cfg(test)]
mod tests {
    use super::*;
    use runmat_builtins::{IntValue, LiteralValue, ResolveContext};

    /// Literal string options are normalised to lowercase.
    #[test]
    fn tokens_from_context_lowercases_strings() {
        let literal = LiteralValue::String("OmItNaN".to_string());
        let ctx = ResolveContext::new(vec![literal]);

        let expected = vec![ArgToken::String("omitnan".to_string())];
        assert_eq!(tokens_from_context(&ctx), expected);
    }

    /// Literal vectors keep their element order and mixed element kinds.
    #[test]
    fn tokens_from_context_handles_vectors() {
        let literal = LiteralValue::Vector(vec![
            LiteralValue::Number(1.0),
            LiteralValue::Bool(true),
        ]);
        let ctx = ResolveContext::new(vec![literal]);

        let expected = vec![ArgToken::Vector(vec![
            ArgToken::Number(1.0),
            ArgToken::Bool(true),
        ])];
        assert_eq!(tokens_from_context(&ctx), expected);
    }

    /// Numeric, integer, boolean and string scalars each map to one token.
    #[test]
    fn tokens_from_values_handles_scalar_inputs() {
        let args = [
            Value::Num(2.0),
            Value::Int(IntValue::I32(3)),
            Value::Bool(true),
            Value::String("All".to_string()),
        ];

        let expected = vec![
            ArgToken::Number(2.0),
            ArgToken::Number(3.0),
            ArgToken::Bool(true),
            ArgToken::String("all".to_string()),
        ];
        assert_eq!(tokens_from_values(&args), expected);
    }

    /// A 1x2 numeric tensor is reported as a vector of numbers.
    #[test]
    fn tokens_from_values_handles_vector_tensor() {
        let row = runmat_builtins::Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let tokens = tokens_from_values(&[Value::Tensor(row)]);

        let expected = vec![ArgToken::Vector(vec![
            ArgToken::Number(1.0),
            ArgToken::Number(2.0),
        ])];
        assert_eq!(tokens, expected);
    }
}